Intel(R) Math Kernel Library for Deep Neural Networks (Intel(R) MKL-DNN)  0.20.0
Performance library for Deep Learning
mkldnn.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2 * Copyright 2016-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef MKLDNN_HPP
18 #define MKLDNN_HPP
19 
20 #ifndef DOXYGEN_SHOULD_SKIP_THIS
21 #include <stdlib.h>
22 #include <memory>
23 #include <vector>
24 #include <algorithm>
25 #include <iterator>
26 #include <string>
27 
28 #include "mkldnn.h"
29 #endif
30 
31 namespace mkldnn {
32 
35 
38 
/// Primary traits template; specialized per C handle type to provide a static
/// `destructor` that releases objects of that type.
template <typename T> class handle_traits {};

/// Reference-counted wrapper around a C API handle.
///
/// @tparam T      The C handle type. It must be a pointer type, because the
///                wrapper stores it in a std::shared_ptr of the pointee.
/// @tparam traits Provides a static `destructor` callable with a @p T. It is
///                invoked when the last owning (non-weak) copy is released.
///
/// Copies share ownership of the wrapped C object. Moves transfer it without
/// touching the reference count.
template <typename T, typename traits=handle_traits<T>> class handle {
private:
    std::shared_ptr<typename std::remove_pointer<T>::type> _data;
protected:
    bool operator==(const T other) const { return other == _data.get(); }
    bool operator!=(const T other) const { return !(*this == other); }
public:
    /// Constructs a wrapper around a C handle.
    /// @param t    The C handle to wrap (may be null).
    /// @param weak If true, the wrapper never calls traits::destructor on @p t.
    handle(T t = nullptr, bool weak = false) { reset(t, weak); }

    handle(const handle &other) = default;
    handle &operator=(const handle &other) = default;

    // Earlier versions declared deleted `const handle &&` overloads here.
    // Rvalues bound to those in preference to the copy operations, so
    // expressions like `h = handle(p)` or `handle h2(std::move(h))` were
    // ill-formed. Defaulted moves make such code valid and cheap.
    handle(handle &&other) = default;
    handle &operator=(handle &&other) = default;

    /// Replaces the wrapped C handle, releasing this wrapper's share of the
    /// previous one.
    /// @param t    The new C handle (may be null).
    /// @param weak If true, @p t is never passed to traits::destructor.
    void reset(T t, bool weak = false) {
        if (weak)
            _data.reset(t, [](T) {}); // non-owning: no-op deleter
        else
            _data.reset(t, traits::destructor);
    }

    /// Returns the underlying C handle. Ownership is not transferred.
    T get() const { return _data.get(); }

    bool operator==(const handle &other) const { return other._data.get() == _data.get(); }
    bool operator!=(const handle &other) const { return !(*this == other); }
};
90 
91 #ifndef DOXYGEN_SHOULD_SKIP_THIS
92 template <> struct handle_traits<mkldnn_primitive_desc_t> {
93  static constexpr auto destructor = &mkldnn_primitive_desc_destroy;
94 };
95 
96 template <> struct handle_traits<mkldnn_primitive_t> {
97  static constexpr auto destructor = &mkldnn_primitive_destroy;
98 };
99 
100 template <> struct handle_traits<mkldnn_primitive_desc_iterator_t> {
101  static constexpr auto destructor = &mkldnn_primitive_desc_iterator_destroy;
102 };
103 #endif
104 
106 class primitive: public handle<mkldnn_primitive_t> {
107  friend struct error;
108  friend struct stream;
109  friend class primitive_at;
110  using handle::handle;
111 public:
113  enum class kind {
114  undefined_primitive = mkldnn_undefined_primitive,
116  view = mkldnn_view,
119  concat_inplace = mkldnn_concat_inplace,
120  sum = mkldnn_sum,
121  convolution = mkldnn_convolution,
122  deconvolution = mkldnn_deconvolution,
123  shuffle = mkldnn_shuffle,
124  eltwise = mkldnn_eltwise,
125  softmax = mkldnn_softmax,
126  pooling = mkldnn_pooling,
127  lrn = mkldnn_lrn,
128  batch_normalization = mkldnn_batch_normalization,
129  inner_product = mkldnn_inner_product,
130  rnn = mkldnn_rnn,
131  };
132 
134  struct at {
142 
143  at(const primitive &aprimitive, size_t at = 0)
144  : data(mkldnn_primitive_at(aprimitive.get(), at)) {}
146  inline operator primitive() const;
147  };
148 
150  inline const_mkldnn_primitive_desc_t get_primitive_desc() const;
151  // TODO: use the C++ API wrapper structure.
152 };
153 
155  return static_cast<mkldnn_primitive_kind_t>(akind);
156 }
161 struct error: public std::exception {
163  std::string message;
165 
172 
173  error(mkldnn_status_t astatus, std::string amessage,
174  mkldnn_primitive_t aerror_primitive = 0)
175  : status(astatus)
176  , message(amessage)
177  , error_primitive(aerror_primitive, true)
178  {}
179 
187 
188  static void wrap_c_api(mkldnn_status_t status,
189  const std::string &message,
190  mkldnn_primitive_t *error_primitive = 0)
191  {
192  if (status != mkldnn_success) {
193  if (nullptr != error_primitive)
194  throw error(status, message, *error_primitive);
195  else
196  throw error(status, message, nullptr);
197  }
198  }
199 };
200 
201 inline primitive::at::operator primitive() const {
204  mkldnn_primitive_get_output(data.primitive,
205  data.output_index, &output),
206  "could not get an output primitive");
207  return primitive(const_cast<mkldnn_primitive_t>(output), true);
208 }
209 
213  "could not get primitive descriptor by primitive");
214  return pd;
215 }
217 
222 
226 };
227 
229  return static_cast<mkldnn_round_mode_t>(mode);
230 }
231 
234 };
235 
237  return static_cast<mkldnn_padding_kind_t>(kind);
238 }
239 
240 enum prop_kind {
249 };
250 
252  return static_cast<mkldnn_prop_kind_t>(kind);
253 }
254 
255 enum algorithm {
283 };
284 
286  return static_cast<mkldnn_alg_kind_t>(aalgorithm);
287 }
288 
293 };
294 
296  batch_normalization_flag aflag) {
297  return static_cast<mkldnn_batch_normalization_flag_t>(aflag);
298 }
299 
306 };
307 
309  return static_cast<mkldnn_rnn_direction_t>(adir);
310 }
311 
312 enum query {
314 
317 
320 
323 
325 
338 
348 };
349 
351  return static_cast<mkldnn_query_t>(aquery);
352 }
353 
355 
361 
362 #ifndef DOXYGEN_SHOULD_SKIP_THIS
363 template <> struct handle_traits<mkldnn_post_ops_t> {
364  static constexpr auto destructor = &mkldnn_post_ops_destroy;
365 };
366 #endif
367 
368 struct post_ops: public handle<mkldnn_post_ops_t> {
370  mkldnn_post_ops_t result;
372  "could not create post operation sequence");
373  reset(result);
374  }
375 
376  int len() const { return mkldnn_post_ops_len(get()); }
377 
378  primitive::kind kind(int index) const {
380  index < len() ? mkldnn_success : mkldnn_invalid_arguments,
381  "post_ops index is out of range");
382  return static_cast<primitive::kind>(mkldnn_post_ops_get_kind(get(),
383  index));
384  }
385 
386  void append_sum(float scale = 1.) {
388  "could not append sum");
389  }
390 
391  void get_params_sum(int index, float &scale) const {
393  "could not get sum params");
394  }
395 
396  void append_eltwise(float scale, algorithm alg, float alpha,
397  float beta) {
399  convert_to_c(alg), alpha, beta),
400  "could not append eltwise");
401  }
402 
403  void get_params_eltwise(int index, float &scale, algorithm &alg,
404  float &alpha, float &beta) const {
405  mkldnn_alg_kind_t c_alg;
407  &scale, &c_alg, &alpha, &beta),
408  "could not get eltwise params");
409  alg = static_cast<algorithm>(c_alg);
410  }
411 };
412 
413 #ifndef DOXYGEN_SHOULD_SKIP_THIS
414 template <> struct handle_traits<mkldnn_primitive_attr_t> {
415  static constexpr auto destructor = &mkldnn_primitive_attr_destroy;
416 };
417 #endif
418 
419 struct primitive_attr: public handle<mkldnn_primitive_attr_t> {
421  mkldnn_primitive_attr_t result;
423  "could not create a primitive attr");
424  reset(result);
425  }
426 
428  mkldnn_round_mode_t result;
430  get(), &result), "could not get int output round mode");
431  return round_mode(result);
432  }
433 
436  get(), mkldnn::convert_to_c(mode)),
437  "could not set int output round mode");
438  }
439 
440  void get_output_scales(int &mask, std::vector<float> &scales) const
441  {
442  int count, c_mask;
443  const float *c_scales;
445  &count, &c_mask, &c_scales),
446  "could not get int output scales");
447  scales.resize(count);
448 
449  mask = c_mask;
450  for (int c = 0; c < count; ++c)
451  scales[c] = c_scales[c];
452  }
453 
454  void set_output_scales(int mask, const std::vector<float> &scales)
455  {
457  (int)scales.size(), mask, &scales[0]),
458  "could not set int output scales");
459  }
460 
461  const post_ops get_post_ops() const {
462  post_ops result;
463  const_mkldnn_post_ops_t c_result;
465  "could not get post operation sequence");
466  result.reset(const_cast<mkldnn_post_ops_t>(c_result), true);
467  return result;
468  }
469 
470  void set_post_ops(post_ops ops) {
472  "could not set post operation sequence");
473  }
474 
475  void set_rnn_data_qparams(const float scale, const float shift)
476  {
478  scale, shift), "could not set rnn data int scale/shift");
479  }
480 
481  void set_rnn_weights_qparams(int mask, const std::vector<float> &scales)
482  {
484  (int)scales.size(), mask, &scales[0]),
485  "could not set rnn weights int scales");
486  }
487 };
488 
490 
496 
497 #ifndef DOXYGEN_SHOULD_SKIP_THIS
498 template <> struct handle_traits<mkldnn_engine_t> {
499  static constexpr auto destructor = &mkldnn_engine_destroy;
500 };
501 #endif
502 
504 struct engine: public handle<mkldnn_engine_t> {
505  friend class primitive;
506  // gcc bug??? using handle::handle;
507 
509  enum kind {
513  cpu = mkldnn_cpu,
514  };
515 
519 
520  static size_t get_count(kind akind) {
521  return mkldnn_engine_get_count(convert_to_c(akind));
522  }
523 
529 
530  engine(kind akind, size_t index) {
531  mkldnn_engine_t aengine;
533  mkldnn_engine_create(&aengine,
534  convert_to_c(akind), index),
535  "could not create an engine");
536  reset(aengine);
537  }
538 
539  explicit engine(const mkldnn_engine_t& aengine)
540  : handle(aengine, true) {}
541 
543  mkldnn_engine_t engine_q;
546  mkldnn::convert_to_c(eengine), 0, &engine_q),
547  "could not get engine from primitive_desc");
548  reset(engine_q, true);
549  }
550 
551  template <class primitive_desc>
552  static engine query(const primitive_desc &pd) {
553  mkldnn_engine_t engine_q;
556  mkldnn::convert_to_c(eengine), 0, &engine_q),
557  "could not get engine from primitive_desc");
558 
559  return engine(engine_q);
560  }
561 
562 private:
563  static mkldnn_engine_kind_t convert_to_c(kind akind) {
564  return static_cast<mkldnn_engine_kind_t>(akind);
565  }
566 };
567 
569 
572 
578 
580 struct memory: public primitive {
581  private:
582  std::shared_ptr<char> _handle;
583 
584  public:
585  typedef std::vector<std::remove_extent<mkldnn_dims_t>::type> dims;
586 
587  template <typename T> static void validate_dims(std::vector<T> v) {
588  if (v.size() > TENSOR_MAX_DIMS)
590  "invalid dimensions");
591  }
592 
595  enum data_type {
597  f32 = mkldnn_f32,
598  s32 = mkldnn_s32,
599  bf16 = mkldnn_bf16,
600  s16 = mkldnn_s16,
601  s8 = mkldnn_s8,
602  u8 = mkldnn_u8,
603  };
604 
607  enum format {
608  format_undef = mkldnn_format_undef,
609  any = mkldnn_any,
610  blocked = mkldnn_blocked,
611  x = mkldnn_x,
612  nc = mkldnn_nc,
613  ncw = mkldnn_ncw,
614  nwc = mkldnn_nwc,
615  nCw16c = mkldnn_nCw16c,
616  nchw = mkldnn_nchw,
617  nhwc = mkldnn_nhwc,
618  chwn = mkldnn_chwn,
619  nCw4c = mkldnn_nCw4c,
620  nCw8c = mkldnn_nCw8c,
621  nChw4c = mkldnn_nChw4c,
622  nChw8c = mkldnn_nChw8c,
623  nChw16c = mkldnn_nChw16c,
624  ncdhw = mkldnn_ncdhw,
625  ndhwc = mkldnn_ndhwc,
626  nCdhw4c = mkldnn_nCdhw4c,
627  nCdhw8c = mkldnn_nCdhw8c,
628  nCdhw16c = mkldnn_nCdhw16c,
629  oi = mkldnn_oi,
630  io = mkldnn_io,
631  oiw = mkldnn_oiw,
632  wio = mkldnn_wio,
633  Owi4o = mkldnn_Owi4o,
634  OIw4i4o = mkldnn_OIw4i4o,
635  Owi8o = mkldnn_Owi8o,
636  OIw8o8i = mkldnn_OIw8o8i,
637  OIw8i8o = mkldnn_OIw8i8o,
638  OIw16i16o = mkldnn_OIw16i16o,
639  OIw16o16i = mkldnn_OIw16o16i,
640  Oiw4o = mkldnn_Oiw4o,
641  Oiw16o = mkldnn_Oiw16o,
642  Owi16o = mkldnn_Owi16o,
643  OIw8i16o2i = mkldnn_OIw8i16o2i,
644  OIw8o16i2o = mkldnn_OIw8o16i2o,
645  IOw8o16i2o = mkldnn_IOw8o16i2o,
646  IOw16o16i = mkldnn_IOw16o16i,
647  OIw4i16o4i = mkldnn_OIw4i16o4i,
648  OIw4i16o4i_s8s8 = mkldnn_OIw4i16o4i_s8s8,
649  oihw = mkldnn_oihw,
650  ihwo = mkldnn_ihwo,
651  hwio = mkldnn_hwio,
652  iohw = mkldnn_iohw,
653  hwio_s8s8 = mkldnn_hwio_s8s8,
654  dhwio = mkldnn_dhwio,
655  oidhw = mkldnn_oidhw,
656  OIdhw4i4o = mkldnn_OIdhw4i4o,
657  Odhwi4o = mkldnn_Odhwi4o,
658  OIdhw8i8o = mkldnn_OIdhw8i8o,
659  OIdhw8o8i = mkldnn_OIdhw8o8i,
660  Odhwi8o = mkldnn_Odhwi8o,
661  OIdhw16i16o = mkldnn_OIdhw16i16o,
662  OIdhw16o16i = mkldnn_OIdhw16o16i,
663  Oidhw4o = mkldnn_Oidhw4o,
664  Oidhw16o = mkldnn_Oidhw16o,
665  Odhwi16o = mkldnn_Odhwi16o,
666  oIhw8i = mkldnn_oIhw8i,
667  oIhw16i = mkldnn_oIhw16i,
668  oIdhw8i = mkldnn_oIdhw8i,
669  oIdhw16i = mkldnn_oIdhw16i,
670  OIhw4i4o = mkldnn_OIhw4i4o,
671  OIhw8i8o = mkldnn_OIhw8i8o,
672  OIhw16i16o = mkldnn_OIhw16i16o,
673  OIhw8o8i = mkldnn_OIhw8o8i,
674  OIhw16o16i = mkldnn_OIhw16o16i,
675  IOhw16o16i = mkldnn_IOhw16o16i,
676  OIhw8i16o2i = mkldnn_OIhw8i16o2i,
677  IOhw8i16o2i = mkldnn_IOhw8i16o2i,
678  OIhw8o16i2o = mkldnn_OIhw8o16i2o,
679  IOhw8o16i2o = mkldnn_IOhw8o16i2o,
680  OIdhw8i16o2i = mkldnn_OIdhw8i16o2i,
681  OIdhw8o16i2o = mkldnn_OIdhw8o16i2o,
682  IOdhw8o16i2o = mkldnn_IOdhw8o16i2o,
683  OIhw4i16o4i = mkldnn_OIhw4i16o4i,
684  OIhw4i16o4i_s8s8 = mkldnn_OIhw4i16o4i_s8s8,
685  Oihw8o = mkldnn_Oihw8o,
686  Oihw4o = mkldnn_Oihw4o,
687  Oihw16o = mkldnn_Oihw16o,
688  Ohwi8o = mkldnn_Ohwi8o,
689  Ohwi4o = mkldnn_Ohwi4o,
690  Ohwi16o = mkldnn_Ohwi16o,
691  OhIw16o4i = mkldnn_OhIw16o4i,
692  goiw = mkldnn_goiw,
693  gOwi4o = mkldnn_gOwi4o,
694  gOIw4i4o = mkldnn_gOIw4i4o,
695  gOwi8o = mkldnn_gOwi8o,
696  gOIw8o8i = mkldnn_gOIw8o8i,
697  gOIw8i8o = mkldnn_gOIw8i8o,
698  gOIw16i16o = mkldnn_gOIw16i16o,
699  gOIw16o16i = mkldnn_gOIw16o16i,
700  gOiw4o = mkldnn_gOiw4o,
701  gOiw16o = mkldnn_gOiw16o,
702  gOwi16o = mkldnn_gOwi16o,
703  gIOw16o16i = mkldnn_gIOw16o16i,
704  gOIw8i16o2i = mkldnn_gOIw8i16o2i,
705  gOIw8o16i2o = mkldnn_gOIw8o16i2o,
706  gIOw8o16i2o = mkldnn_gIOw8o16i2o,
707  gOIw4i16o4i = mkldnn_gOIw4i16o4i,
708  gOIw4i16o4i_s8s8 = mkldnn_gOIw4i16o4i_s8s8,
709  goihw = mkldnn_goihw,
710  hwigo = mkldnn_hwigo,
711  giohw = mkldnn_giohw,
712  hwigo_s8s8 = mkldnn_hwigo_s8s8,
713  gOIdhw4i4o = mkldnn_gOIdhw4i4o,
714  gOdhwi4o = mkldnn_gOdhwi4o,
715  gOIdhw8i8o = mkldnn_gOIdhw8i8o,
716  gOIdhw8o8i = mkldnn_gOIdhw8o8i,
717  gOdhwi8o = mkldnn_gOdhwi8o,
718  gOIhw4i4o = mkldnn_gOIhw4i4o,
719  gOIhw8i8o = mkldnn_gOIhw8i8o,
720  gOIhw16i16o = mkldnn_gOIhw16i16o,
721  gOIhw8i16o2i = mkldnn_gOIhw8i16o2i,
722  gIOhw8i16o2i = mkldnn_gIOhw8i16o2i,
723  gOIhw8o16i2o = mkldnn_gOIhw8o16i2o,
724  gIOhw8o16i2o = mkldnn_gIOhw8o16i2o,
725  gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i,
726  gOIdhw8o16i2o = mkldnn_gOIdhw8o16i2o,
727  gIOdhw8o16i2o = mkldnn_gIOdhw8o16i2o,
728  gOIhw4i16o4i = mkldnn_gOIhw4i16o4i,
729  gOIhw4i16o4i_s8s8 = mkldnn_gOIhw4i16o4i_s8s8,
730  gOIhw2i8o4i = mkldnn_gOIhw2i8o4i,
731  gOIhw2i8o4i_s8s8 = mkldnn_gOIhw2i8o4i_s8s8,
732  gOihw8o = mkldnn_gOihw8o,
733  gOihw4o = mkldnn_gOihw4o,
734  gOihw16o = mkldnn_gOihw16o,
735  gOhwi4o = mkldnn_gOhwi4o,
736  gOhwi8o = mkldnn_gOhwi8o,
737  gOhwi16o = mkldnn_gOhwi16o,
738  Goihw8g = mkldnn_Goihw8g,
739  Goiw16g = mkldnn_Goiw16g,
740  Goiw16g_s8s8 = mkldnn_Goiw16g_s8s8,
741  Goihw16g = mkldnn_Goihw16g,
742  Goihw16g_s8s8 = mkldnn_Goihw16g_s8s8,
743  gOIhw4o4i = mkldnn_gOIhw4o4i,
744  gOIhw4o4i_s8s8 = mkldnn_gOIhw4o4i_s8s8,
745  gOIhw8o8i = mkldnn_gOIhw8o8i,
746  gOIhw16o16i = mkldnn_gOIhw16o16i,
747  gIOhw16o16i = mkldnn_gIOhw16o16i,
748  gOhIw16o4i = mkldnn_gOhIw16o4i,
749  goidhw = mkldnn_goidhw,
750  gOIdhw16i16o = mkldnn_gOIdhw16i16o,
751  gOIdhw16o16i = mkldnn_gOIdhw16o16i,
752  gOidhw4o = mkldnn_gOidhw4o,
753  gOidhw16o = mkldnn_gOidhw16o,
754  gOdhwi16o = mkldnn_gOdhwi16o,
755  ntc = mkldnn_ntc,
756  tnc = mkldnn_tnc,
757  ldsnc = mkldnn_ldsnc,
758  ldigo = mkldnn_ldigo,
759  ldgoi = mkldnn_ldgoi,
760  ldgo = mkldnn_ldgo,
761  rnn_packed = mkldnn_rnn_packed,
762  wino_fmt = mkldnn_wino_fmt,
763  format_last = mkldnn_format_last,
764  };
765 
767  struct desc {
768  friend struct memory;
771 
777  desc(dims adims, data_type adata_type,
778  format aformat) {
779  validate_dims(adims);
781  mkldnn_memory_desc_init(&data, (int)adims.size(),
782  adims.size() == 0 ? nullptr : &adims[0],
783  convert_to_c(adata_type), convert_to_c(aformat)),
784  "could not initialize a memory descriptor");
785  }
786 
790  desc(const mkldnn_memory_desc_t &adata): data(adata) {}
791  };
792 
794  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
795  friend struct memory;
796 
797  // TODO: make private
799 
801  primitive_desc(const desc &adesc, const engine &aengine) {
802  mkldnn_primitive_desc_t result;
805  &adesc.data, aengine.get()),
806  "could not initialize a memory primitive descriptor");
807  reset(result);
808  }
809 
813  return memory::desc(*memory_d); }
814 
817  size_t get_size() const {
819  }
820 
821  bool operator==(const primitive_desc &other) const {
822  return (0 == mkldnn_memory_primitive_desc_equal(get(),
823  other.get())) ? false : true;
824  }
825 
826  bool operator!=(const primitive_desc &other) const {
827  return !operator==(other);
828  }
829 
830  engine get_engine() { return engine::query(*this); }
831  };
832 
836  memory(const primitive &aprimitive): primitive(aprimitive) {}
840  memory(const primitive_desc &adesc) {
841  mkldnn_primitive_t result;
843  mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
844  "could not create a memory primitive");
845  reset(result);
846  auto _malloc = [](size_t size, int alignment) {
847  void *ptr;
848 #ifdef _WIN32
849  ptr = _aligned_malloc(size, alignment);
850  int rc = ((ptr)? 0 : errno);
851 #else
852  int rc = ::posix_memalign(&ptr, alignment, size);
853 #endif /* _WIN32 */
854  return (rc == 0) ? (char*)ptr : nullptr;
855  };
856  auto _free = [](char* p) {
857 #ifdef _WIN32
858  _aligned_free((void*)p);
859 #else
860  ::free((void*)p);
861 #endif /* _WIN32 */
862  };
863  _handle.reset(_malloc(adesc.get_size(), 4096), _free);
864  set_data_handle(_handle.get());
865  }
866 
867  memory(const primitive_desc &adesc, void *ahandle) {
868  mkldnn_primitive_t result;
870  mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
871  "could not create a memory primitive");
872  reset(result);
873  set_data_handle(ahandle);
874  }
875 
878  primitive_desc adesc;
881  &cdesc),
882  "could not get primitive descriptor from a memory primitive");
883  /* FIXME: no const_cast should be here */
884  adesc.reset(const_cast<mkldnn_primitive_desc_t>(cdesc), true);
885  return adesc;
886  }
887 
890  inline void *get_data_handle() const {
891  void *handle;
893  "could not get native handle");
894  return handle;
895  }
896 
897  inline void set_data_handle(void *handle) const {
899  "could not set native handle");
900  }
901 
902  // Must go away or be private:
904  return static_cast<mkldnn_data_type_t>(adata_type);
905  }
907  return static_cast<mkldnn_memory_format_t>(aformat);
908  }
909 };
910 
912  auto zero = mkldnn_memory_desc_t();
913  zero.primitive_kind = mkldnn_memory;
914  return memory::desc(zero);
915 }
916 
917 inline memory null_memory(engine eng) {
919  return memory({zero, eng}, nullptr);
920 }
921 
923  &aprimitive_desc, int n_inputs, int n_outputs,
924  const std::string &prim_name) {
925  const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
926  aprimitive_desc, mkldnn_query_num_of_inputs_s32, 0);
927  const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
928  aprimitive_desc, mkldnn_query_num_of_outputs_s32, 0);
929  if (n_outputs_expected > n_outputs ) {
930  std::string message = "could not create " + prim_name +
931  " primitive, not enought output parameters";
932  throw error(mkldnn_invalid_arguments, message, nullptr);
933  }
934  if (n_inputs_expected > n_inputs ) {
935  std::string message = "could not create " + prim_name +
936  " primitive, not enought input parameters";
937  throw error(mkldnn_invalid_arguments, message, nullptr);
938  }
939 }
940 
941 
942 inline bool is_null_memory(const const_mkldnn_primitive_t &aprimitive) {
943  const_mkldnn_primitive_desc_t aprimitive_pd;
944  mkldnn_primitive_get_primitive_desc(aprimitive, &aprimitive_pd);
946  aprimitive_pd);
947 
948  return ((aprimitive_md != nullptr) && (aprimitive_md->ndims == 0));
949 }
950 
952  return a == memory::convert_to_c(b);
953 }
955  return !(a == b);
956 }
958  return b == a;
959 }
961  return !(a == b);
962 }
963 
965  return a == memory::convert_to_c(b);
966 }
968  return !(a == b);
969 }
971  return b == a;
972 }
974  return !(a == b);
975 }
976 
978 
984 
985 struct reorder : public primitive {
986  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
988  const memory::primitive_desc &output) {
989  mkldnn_primitive_desc_t result;
991  &result, input.get(), output.get()),
992  "could not create a reorder primitive descriptor");
993  reset(result);
994  }
995 
997  const memory::primitive_desc &output,
998  const primitive_attr &aattr) {
999  mkldnn_primitive_desc_t result;
1001  &result, input.get(), output.get(), aattr.get()),
1002  "could not create a reorder primitive descriptor");
1003  reset(result);
1004  }
1005 
1006  engine get_engine() { return engine::query(*this); }
1007  };
1008 
1009  reorder(const primitive_desc &aprimitive_desc,
1010  const primitive::at &input, const memory &output) {
1011  mkldnn_primitive_t result;
1012  mkldnn_primitive_at_t inputs[] = { input.data };
1013  const_mkldnn_primitive_t outputs[] = { output.get() };
1015  aprimitive_desc.get(), inputs, outputs),
1016  "could not create a reorder primitive");
1017  reset(result);
1018  }
1019 
1020  reorder(const primitive::at &input, const memory &output) {
1021  auto input_mpd = memory(input).get_primitive_desc();
1022  auto output_mpd = output.get_primitive_desc();
1023 
1024  auto reorder_d = primitive_desc(input_mpd, output_mpd);
1025 
1026  mkldnn_primitive_t result;
1027  mkldnn_primitive_at_t inputs[] = { input.data };
1028  const_mkldnn_primitive_t outputs[] = { output.get() };
1030  reorder_d.get(), inputs, outputs),
1031  "could not create a reorder primitive");
1032  reset(result);
1033  }
1034 };
1035 
1037 
1043 
1044 struct view : public primitive {
1045  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1047  memory::dims offsets) {
1048  mkldnn_primitive_desc_t result;
1049 
1051  &result, input.get(), &dims[0], &offsets[0]),
1052  "could not create a view primitive descriptor");
1053  reset(result);
1054  }
1055 
1057  memory::primitive_desc adesc;
1058  mkldnn_primitive_desc_t cdesc;
1059  const_mkldnn_primitive_desc_t const_cdesc =
1063  const_cdesc),
1064  "could not clone a dst primitive descriptor");
1065  adesc.reset(cdesc);
1066  return adesc;
1067  }
1068 
1069  engine get_engine() { return engine::query(*this); }
1070  };
1071 
1072  view(const primitive_desc &view_pd, primitive::at input) {
1073  mkldnn_primitive_t result;
1074  mkldnn_primitive_at_t inputs[] = { input.data };
1076  view_pd.get(), inputs, nullptr),
1077  "could not create a view primitive");
1078  reset(result);
1079  }
1080 
1081  view(memory input, memory::dims dims, memory::dims offsets) {
1082  mkldnn_primitive_t result;
1083  primitive_desc view_pd(input.get_primitive_desc(), dims,
1084  offsets);
1085  mkldnn_primitive_at_t inputs[] = { primitive::at(input).data };
1087  view_pd.get(), inputs, nullptr),
1088  "could not create a view primitive");
1089  reset(result);
1090  }
1091 };
1092 
1094 
1100 
1101 struct concat : public primitive {
1102  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1103  std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1104  std::vector<memory::primitive_desc> inputs) {
1105  std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1106  c_api_inputs.reserve(inputs.size());
1107  auto convert_to_c = [](memory::primitive_desc d) { return d.get(); };
1108  std::transform(inputs.begin(), inputs.end(),
1109  std::back_inserter(c_api_inputs), convert_to_c);
1110  return c_api_inputs;
1111  }
1112 
1113  primitive_desc(const memory::desc &output, int concat_dimension,
1114  std::vector<memory::primitive_desc> inputs) {
1115  mkldnn_primitive_desc_t result;
1116 
1117  auto c_api_inputs = cpp_to_c(inputs);
1118 
1120  &result, &output.data, (int)c_api_inputs.size(),
1121  concat_dimension, &c_api_inputs[0]),
1122  "could not create a concat primitive descriptor");
1123  reset(result);
1124  }
1125 
1126  primitive_desc(int concat_dimension,
1127  std::vector<memory::primitive_desc> inputs) {
1128  mkldnn_primitive_desc_t result;
1129 
1130  auto c_api_inputs = cpp_to_c(inputs);
1131 
1133  &result, nullptr, (int)c_api_inputs.size(),
1134  concat_dimension, &c_api_inputs[0]),
1135  "could not create a concat primitive descriptor");
1136  reset(result);
1137  }
1138 
1140  memory::primitive_desc adesc;
1141  mkldnn_primitive_desc_t cdesc;
1142  const_mkldnn_primitive_desc_t const_cdesc =
1145  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1146  "could not clone a dst primitive descriptor");
1147  adesc.reset(cdesc);
1148  return adesc;
1149  }
1150 
1151  engine get_engine() { return engine::query(*this); }
1152  };
1153 
1154  concat(const primitive_desc &concat_pd,
1155  std::vector<primitive::at> &inputs, const memory &output) {
1156  mkldnn_primitive_t result;
1157 
1158  std::vector<mkldnn_primitive_at_t> p_inputs;
1159  for (size_t i = 0; i < inputs.size(); i++)
1160  p_inputs.push_back(inputs[i].data);
1161  const_mkldnn_primitive_t outputs[] = { output.get() };
1162 
1164  concat_pd.get(), &p_inputs[0], outputs),
1165  "could not create a concat primitive");
1166  reset(result);
1167  }
1168 };
1169 
1171 
1177 
1178 struct sum : public primitive {
1179  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1180  std::vector<const_mkldnn_primitive_desc_t> cpp_to_c(
1181  std::vector<memory::primitive_desc> inputs) {
1182  std::vector<const_mkldnn_primitive_desc_t> c_api_inputs;
1183  c_api_inputs.reserve(inputs.size());
1184  auto convert_to_c = [](memory::primitive_desc d) { return d.get();};
1185  std::transform(inputs.begin(), inputs.end(),
1186  std::back_inserter(c_api_inputs), convert_to_c);
1187  return c_api_inputs;
1188  }
1189 
1191  const std::vector<float> &scales,
1192  std::vector<memory::primitive_desc> inputs) {
1193  mkldnn_primitive_desc_t result;
1194 
1195  auto c_api_inputs = cpp_to_c(inputs);
1196 
1198  scales.size() == inputs.size() ? mkldnn_success
1200  "number of scales not equal to number of inputs");
1201 
1203  &result, &output.data, (int)c_api_inputs.size(),
1204  &scales[0], &c_api_inputs[0]),
1205  "could not create a sum primitive descriptor");
1206  reset(result);
1207  }
1208 
1209  primitive_desc(const std::vector<float> &scales,
1210  std::vector<memory::primitive_desc> inputs) {
1211  mkldnn_primitive_desc_t result;
1212 
1213  auto c_api_inputs = cpp_to_c(inputs);
1214 
1216  scales.size() == inputs.size() ? mkldnn_success
1218  "number of scales not equal to number of inputs");
1219 
1221  &result, nullptr, (int)c_api_inputs.size(), &scales[0],
1222  &c_api_inputs[0]),
1223  "could not create a sum primitive descriptor");
1224  reset(result);
1225  }
1226 
1228  memory::primitive_desc adesc;
1229  mkldnn_primitive_desc_t cdesc;
1230  const_mkldnn_primitive_desc_t const_cdesc =
1234  const_cdesc),
1235  "could not clone a dst primitive descriptor");
1236  adesc.reset(cdesc);
1237  return adesc;
1238  }
1239 
1240  engine get_engine() { return engine::query(*this); }
1241  };
1242 
1243  sum(const primitive_desc &sum_pd,
1244  std::vector<primitive::at> &inputs, const memory &output) {
1245  mkldnn_primitive_t result;
1246 
1247  std::vector<mkldnn_primitive_at_t> p_inputs;
1248  for (size_t i = 0; i < inputs.size(); i++)
1249  p_inputs.push_back(inputs[i].data);
1250  const_mkldnn_primitive_t outputs[] = { output.get() };
1251 
1253  sum_pd.get(), &p_inputs[0], outputs),
1254  "could not create a sum primitive");
1255  reset(result);
1256  }
1257 };
1258 
1260 
1262 
1265 
1268 
1270 struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
1272  const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) {
1273  mkldnn_primitive_desc_iterator_t iterator = nullptr;
1275  &iterator, desc, attr ? attr->get() : nullptr, e.get(),
1276  hint_fwd_pd);
1277  error::wrap_c_api(status,
1278  "could not create a primitive descriptor iterator");
1279  pd_iterator.reset(iterator);
1280  fetch_impl();
1281  }
1282 
1283  engine get_engine() { return engine::query(*this); }
1284 
1286  const_mkldnn_primitive_attr_t const_cattr;
1288  "could not get attributes");
1289  mkldnn_primitive_attr_t cattr;
1290  error::wrap_c_api(mkldnn_primitive_attr_clone(&cattr, const_cattr),
1291  "could not clone attributes");
1292 
1293  primitive_attr attr;
1294  attr.reset(cattr);
1295  return attr;
1296  }
1297 
1299  const char *impl_info_str() const {
1300  const char *res;
1302  mkldnn_query_impl_info_str, 0, &res),
1303  "could not query implementation info string");
1304  return res;
1305  }
1306 
1313  bool next_impl() {
1315  pd_iterator.get());
1316  if (status == mkldnn_iterator_ends) return false;
1317  error::wrap_c_api(status, "primitive descriptor iterator next failed");
1318 
1319  fetch_impl();
1320  return true;
1321  }
1322 
1324  memory::primitive_desc query_mpd(query what, int idx = 0) const {
1325  std::vector<query> valid_w{input_pd, output_pd, src_pd, diff_src_pd,
1327  if (!std::any_of(valid_w.cbegin(), valid_w.cend(),
1328  [=](query q) { return what == q; }))
1329  throw error(mkldnn_invalid_arguments, "invalid memory query");
1330 
1331  const_mkldnn_primitive_desc_t const_cdesc
1333  mkldnn::convert_to_c(what), idx);
1334 
1335  // TODO: is there a better way to inform about this?
1336  if (const_cdesc == nullptr)
1337  throw error(mkldnn_not_required, "queried memory is not required");
1338 
1339  mkldnn_primitive_desc_t cdesc;
1340  error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
1341  "could not clone a memory primitive descriptor");
1342 
1344  ret.reset(cdesc);
1345  return ret;
1346  }
1347 
1348  // register specialized queries, e.g. src_primitive_desc()
// Expands to a const accessor `<name>_primitive_desc()` that forwards to
// query_mpd(<what>_pd, idx).
#define REG_QUERY_MPD(name, what, idx) \
    memory::primitive_desc name ## _primitive_desc() const { \
        return query_mpd(what ## _pd, idx); \
    }
1352 
1353  private:
1354  handle<mkldnn_primitive_desc_iterator_t> pd_iterator;
1355  void fetch_impl() {
1356  mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch(
1357  pd_iterator.get());
1359  "could not fetch a primitive descriptor from the iterator");
1360  reset(pd);
1361  }
1362 };
1363 
1365 
1371 
// Operation descriptor for forward convolution. Four overloads: with or
// without a bias memory descriptor, and with or without `dilates`.
// The no-bias overloads pass nullptr in the bias-descriptor position.
// Every dims argument goes through memory::validate_dims before a pointer
// to its first element (`&v[0]`) is passed on.
// NOTE(review): dims are taken by value (one copy per argument). Also,
// `&v[0]` is undefined for an empty std::vector (assuming memory::dims is
// one), so confirm that validate_dims (defined elsewhere) rejects empty dims.
1373  struct desc {
// With bias, without dilation.
1375  desc(prop_kind aprop_kind, algorithm aalgorithm,
1376  const memory::desc &src_desc,
1377  const memory::desc &weights_desc,
1378  const memory::desc &bias_desc,
1379  const memory::desc &dst_desc,
1380  const memory::dims strides,
1381  const memory::dims padding_l,
1382  const memory::dims padding_r,
1383  const padding_kind apadding_kind) {
1384  memory::validate_dims(strides);
1385  memory::validate_dims(padding_l);
1386  memory::validate_dims(padding_r);
1388  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1389  &src_desc.data, &weights_desc.data, &bias_desc.data,
1390  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1391  mkldnn::convert_to_c(apadding_kind)),
1392  "could not create a convolution forward descriptor");
1393  }
// Without bias (nullptr for the bias descriptor), without dilation.
1394  desc(prop_kind aprop_kind, algorithm aalgorithm,
1395  const memory::desc &src_desc,
1396  const memory::desc &weights_desc,
1397  const memory::desc &dst_desc,
1398  const memory::dims strides,
1399  const memory::dims padding_l,
1400  const memory::dims padding_r,
1401  const padding_kind apadding_kind) {
1402  memory::validate_dims(strides);
1403  memory::validate_dims(padding_l);
1404  memory::validate_dims(padding_r);
1406  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1407  &src_desc.data, &weights_desc.data, nullptr,
1408  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1409  mkldnn::convert_to_c(apadding_kind)),
1410  "could not create a convolution forward descriptor");
1411  }
// With bias and dilation.
1412  desc(prop_kind aprop_kind, algorithm aalgorithm,
1413  const memory::desc &src_desc,
1414  const memory::desc &weights_desc,
1415  const memory::desc &bias_desc,
1416  const memory::desc &dst_desc,
1417  const memory::dims strides,
1418  const memory::dims dilates,
1419  const memory::dims padding_l,
1420  const memory::dims padding_r,
1421  const padding_kind apadding_kind) {
1422  memory::validate_dims(strides);
1423  memory::validate_dims(dilates);
1424  memory::validate_dims(padding_l);
1425  memory::validate_dims(padding_r);
1428  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1429  &src_desc.data, &weights_desc.data, &bias_desc.data,
1430  &dst_desc.data, &strides[0], &dilates[0],
1431  &padding_l[0], &padding_r[0],
1432  mkldnn::convert_to_c(apadding_kind)),
1433  "could not create a dilated convolution forward descriptor");
1434  }
// Without bias (nullptr for the bias descriptor), with dilation.
1435  desc(prop_kind aprop_kind, algorithm aalgorithm,
1436  const memory::desc &src_desc,
1437  const memory::desc &weights_desc,
1438  const memory::desc &dst_desc,
1439  const memory::dims strides,
1440  const memory::dims dilates,
1441  const memory::dims padding_l,
1442  const memory::dims padding_r,
1443  const padding_kind apadding_kind) {
1444  memory::validate_dims(strides);
1445  memory::validate_dims(dilates);
1446  memory::validate_dims(padding_l);
1447  memory::validate_dims(padding_r);
1450  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1451  &src_desc.data, &weights_desc.data, nullptr,
1452  &dst_desc.data, &strides[0], &dilates[0],
1453  &padding_l[0], &padding_r[0],
1454  mkldnn::convert_to_c(apadding_kind)),
1455  "could not create a dilated convolution forward descriptor");
1456  }
1457  };
1458 
// Primitive descriptor for forward convolution, optionally with a
// primitive_attr. No forward hint is involved (the last base argument is
// nullptr). REG_QUERY_MPD(name, what, idx) defines
// name_primitive_desc(), which returns query_mpd(what_pd, idx). The bias
// memory descriptor is therefore queried as index 1 of the weights kind.
1460  primitive_desc(const desc &desc, const engine &e)
1461  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1462 
1463  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1464  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1465 
1466  REG_QUERY_MPD(src, src, 0);
1467  REG_QUERY_MPD(weights, weights, 0);
1468  REG_QUERY_MPD(bias, weights, 1);
1469  REG_QUERY_MPD(dst, dst, 0);
1470  };
1471 
1472  convolution_forward(const primitive_desc &aprimitive_desc,
1473  const primitive::at &src, const primitive::at &weights,
1474  const primitive::at &bias, const memory &dst) {
1475  mkldnn_primitive_t result;
1476  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1477  bias.data };
1478  const_mkldnn_primitive_t outputs[] = { dst.get() };
1480  aprimitive_desc.get(), inputs, outputs),
1481  "could not create a convolution forward bias primitive");
1482  reset(result);
1483  }
1484 
// Creates a forward convolution primitive without bias. Inputs are
// {src, weights}, the single output is dst. The expected counts
// (2 inputs, 1 output) are passed to check_num_parameters before the
// primitive is created.
1485  convolution_forward(const primitive_desc &aprimitive_desc,
1486  const primitive::at &src, const primitive::at &weights,
1487  const memory &dst) {
1488  mkldnn_primitive_t result;
1489  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1490  const_mkldnn_primitive_t outputs[] = { dst.get() };
1491  check_num_parameters(aprimitive_desc.get(), 2, 1,
1492  "convolution forward");
1494  aprimitive_desc.get(), inputs, outputs),
1495  "could not create a convolution forward primitive");
1496  reset(result);
1497  }
1498 };
1499 
// Operation descriptor for convolution backward-data, built from the memory
// descriptors of diff_src, weights and diff_dst. Two overloads: without and
// with `dilates`. Every dims argument is checked with memory::validate_dims
// before a pointer to its first element is passed on.
// NOTE(review): `&v[0]` is undefined for an empty std::vector (assuming
// memory::dims is one), so confirm that validate_dims rejects empty dims.
1501  struct desc {
// Without dilation.
1503  desc(algorithm aalgorithm,
1504  const memory::desc &diff_src_desc,
1505  const memory::desc &weights_desc,
1506  const memory::desc &diff_dst_desc,
1507  const memory::dims strides,
1508  const memory::dims padding_l,
1509  const memory::dims padding_r,
1510  const padding_kind apadding_kind) {
1511  memory::validate_dims(strides);
1512  memory::validate_dims(padding_l);
1513  memory::validate_dims(padding_r);
1515  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1516  &weights_desc.data, &diff_dst_desc.data,
1517  &strides[0], &padding_l[0], &padding_r[0],
1518  mkldnn::convert_to_c(apadding_kind)),
1519  "could not create a convolution backward data descriptor");
1520  }
// With dilation. It reports the same error message as the non-dilated
// overload.
1521  desc(algorithm aalgorithm,
1522  const memory::desc &diff_src_desc,
1523  const memory::desc &weights_desc,
1524  const memory::desc &diff_dst_desc,
1525  const memory::dims strides,
1526  const memory::dims dilates,
1527  const memory::dims padding_l,
1528  const memory::dims padding_r,
1529  const padding_kind apadding_kind) {
1530  memory::validate_dims(strides);
1531  memory::validate_dims(dilates);
1532  memory::validate_dims(padding_l);
1533  memory::validate_dims(padding_r);
1536  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1537  &weights_desc.data, &diff_dst_desc.data,
1538  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1539  mkldnn::convert_to_c(apadding_kind)),
1540  "could not create a convolution backward data descriptor");
1541  }
1542  };
1543 
// Primitive descriptor for convolution backward-data. The forward
// convolution primitive descriptor is required as a hint, and its handle
// (hint_fwd_pd.get()) is forwarded to the base. The second overload also
// takes a primitive_attr. Generated accessors: diff_src_primitive_desc(),
// weights_primitive_desc() and diff_dst_primitive_desc().
1545  primitive_desc(const desc &desc, const engine &e,
1546  const convolution_forward::primitive_desc &hint_fwd_pd)
1547  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1548 
1549  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1550  const convolution_forward::primitive_desc &hint_fwd_pd)
1551  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1552 
1553  REG_QUERY_MPD(diff_src, diff_src, 0);
1554  REG_QUERY_MPD(weights, weights, 0);
1555  REG_QUERY_MPD(diff_dst, diff_dst, 0);
1556  };
1557 
1559  const primitive::at &diff_dst, const primitive::at &weights,
1560  const memory &diff_src) {
// Convolution backward-data primitive. Inputs are ordered
// {diff_dst, weights} and the single output is diff_src. The expected
// counts (2 inputs, 1 output) are passed to check_num_parameters before the
// primitive is created.
1561  mkldnn_primitive_t result;
1562  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
1563  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1564  check_num_parameters(aprimitive_desc.get(), 2, 1,
1565  "convolution backward data");
1567  aprimitive_desc.get(), inputs, outputs),
1568  "could not create a convolution backward data primitive");
1569  reset(result);
1570  }
1571 };
1572 
// Operation descriptor for convolution backward-weights, built from the
// memory descriptors of src, diff_weights, an optional diff_bias, and
// diff_dst. Four overloads: with or without diff_bias, and with or without
// `dilates`. The no-bias overloads pass nullptr in the diff_bias position.
// Every dims argument is checked with memory::validate_dims before a
// pointer to its first element is passed on.
1574  struct desc {
// With diff_bias, without dilation.
1576  desc(algorithm aalgorithm,
1577  const memory::desc &src_desc,
1578  const memory::desc &diff_weights_desc,
1579  const memory::desc &diff_bias_desc,
1580  const memory::desc &diff_dst_desc,
1581  const memory::dims strides,
1582  const memory::dims padding_l,
1583  const memory::dims padding_r,
1584  const padding_kind apadding_kind) {
1585  memory::validate_dims(strides);
1586  memory::validate_dims(padding_l);
1587  memory::validate_dims(padding_r);
1589  &data, convert_to_c(aalgorithm), &src_desc.data,
1590  &diff_weights_desc.data, &diff_bias_desc.data,
1591  &diff_dst_desc.data,
1592  &strides[0], &padding_l[0], &padding_r[0],
1593  mkldnn::convert_to_c(apadding_kind)),
1594  "could not create a convolution backward weights descriptor");
1595  }
// Without diff_bias (nullptr), without dilation.
1596  desc(algorithm aalgorithm,
1597  const memory::desc &src_desc,
1598  const memory::desc &diff_weights_desc,
1599  const memory::desc &diff_dst_desc,
1600  const memory::dims strides,
1601  const memory::dims padding_l,
1602  const memory::dims padding_r,
1603  const padding_kind apadding_kind) {
1604  memory::validate_dims(strides);
1605  memory::validate_dims(padding_l);
1606  memory::validate_dims(padding_r);
1608  &data, convert_to_c(aalgorithm), &src_desc.data,
1609  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1610  &strides[0], &padding_l[0], &padding_r[0],
1611  mkldnn::convert_to_c(apadding_kind)),
1612  "could not create a convolution backward weights descriptor");
1613  }
// With diff_bias and dilation.
1614  desc(algorithm aalgorithm,
1615  const memory::desc &src_desc,
1616  const memory::desc &diff_weights_desc,
1617  const memory::desc &diff_bias_desc,
1618  const memory::desc &diff_dst_desc,
1619  const memory::dims strides,
1620  const memory::dims dilates,
1621  const memory::dims padding_l,
1622  const memory::dims padding_r,
1623  const padding_kind apadding_kind) {
1624  memory::validate_dims(strides);
1625  memory::validate_dims(dilates);
1626  memory::validate_dims(padding_l);
1627  memory::validate_dims(padding_r);
1629  &data, convert_to_c(aalgorithm), &src_desc.data,
1630  &diff_weights_desc.data, &diff_bias_desc.data,
1631  &diff_dst_desc.data,
1632  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1633  mkldnn::convert_to_c(apadding_kind)),
1634  "could not create a convolution backward weights descriptor");
1635  }
// Without diff_bias (nullptr), with dilation.
1636  desc(algorithm aalgorithm,
1637  const memory::desc &src_desc,
1638  const memory::desc &diff_weights_desc,
1639  const memory::desc &diff_dst_desc,
1640  const memory::dims strides,
1641  const memory::dims dilates,
1642  const memory::dims padding_l,
1643  const memory::dims padding_r,
1644  const padding_kind apadding_kind) {
1645  memory::validate_dims(strides);
1646  memory::validate_dims(dilates);
1647  memory::validate_dims(padding_l);
1648  memory::validate_dims(padding_r);
1650  &data, convert_to_c(aalgorithm), &src_desc.data,
1651  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1652  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1653  mkldnn::convert_to_c(apadding_kind)),
1654  "could not create a convolution backward weights descriptor");
1655  }
1656 
1657  };
1658 
// Primitive descriptor for convolution backward-weights. It requires the
// forward convolution primitive descriptor as a hint; the second overload
// also takes a primitive_attr. Accessors: src_, diff_weights_, diff_bias_
// and diff_dst_primitive_desc(). diff_bias is queried as index 1 of the
// diff_weights kind.
1660  primitive_desc(const desc &desc, const engine &e,
1661  const convolution_forward::primitive_desc &hint_fwd_pd)
1662  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1663 
1664  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1665  const convolution_forward::primitive_desc &hint_fwd_pd)
1666  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1667 
1668  REG_QUERY_MPD(src, src, 0);
1669  REG_QUERY_MPD(diff_weights, diff_weights, 0);
1670  REG_QUERY_MPD(diff_bias, diff_weights, 1);
1671  REG_QUERY_MPD(diff_dst, diff_dst, 0);
1672  };
1673 
1675  const primitive::at &src, const primitive::at &diff_dst,
1676  const memory &diff_weights, const memory &diff_bias) {
// Backward-weights primitive that also produces diff_bias. Inputs are
// {src, diff_dst} and outputs are {diff_weights, diff_bias}, so
// check_num_parameters is told 2 inputs and 2 outputs.
1677  mkldnn_primitive_t result;
1678  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1679  const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
1680  diff_bias.get() };
1681  check_num_parameters(aprimitive_desc.get(), 2, 2,
1682  "convolution backward weights");
1684  aprimitive_desc.get(), inputs, outputs),
1685  "could not create a convolution backward weights primitive");
1686  reset(result);
1687  }
1689  const primitive::at &src, const primitive::at &diff_dst,
1690  const memory &diff_weights) {
// Backward-weights primitive without diff_bias. Inputs are {src, diff_dst}
// and the single output is diff_weights, so check_num_parameters is told
// 2 inputs and 1 output.
1691  mkldnn_primitive_t result;
1692  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
1693  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
1694  check_num_parameters(aprimitive_desc.get(), 2, 1,
1695  "convolution backward weights");
1697  aprimitive_desc.get(), inputs, outputs),
1698  "could not create a convolution backward weights primitive");
1699  reset(result);
1700  }
1701 };
1702 
1704 //
1710 
// Operation descriptor for forward deconvolution. Its shape matches the
// forward convolution descriptor: four overloads, with or without bias and
// with or without `dilates`. The no-bias overloads pass nullptr in the bias
// position. Every dims argument is checked with memory::validate_dims
// before a pointer to its first element is passed on.
1712  struct desc {
// With bias, without dilation.
1714  desc(prop_kind aprop_kind, algorithm aalgorithm,
1715  const memory::desc &src_desc,
1716  const memory::desc &weights_desc,
1717  const memory::desc &bias_desc,
1718  const memory::desc &dst_desc,
1719  const memory::dims strides,
1720  const memory::dims padding_l,
1721  const memory::dims padding_r,
1722  const padding_kind apadding_kind) {
1723  memory::validate_dims(strides);
1724  memory::validate_dims(padding_l);
1725  memory::validate_dims(padding_r);
1727  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1728  &src_desc.data, &weights_desc.data, &bias_desc.data,
1729  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1730  mkldnn::convert_to_c(apadding_kind)),
1731  "could not create a deconvolution forward descriptor");
1732  }
// Without bias (nullptr), without dilation.
1733  desc(prop_kind aprop_kind, algorithm aalgorithm,
1734  const memory::desc &src_desc,
1735  const memory::desc &weights_desc,
1736  const memory::desc &dst_desc,
1737  const memory::dims strides,
1738  const memory::dims padding_l,
1739  const memory::dims padding_r,
1740  const padding_kind apadding_kind) {
1741  memory::validate_dims(strides);
1742  memory::validate_dims(padding_l);
1743  memory::validate_dims(padding_r);
1745  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1746  &src_desc.data, &weights_desc.data, nullptr,
1747  &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
1748  mkldnn::convert_to_c(apadding_kind)),
1749  "could not create a deconvolution forward descriptor");
1750  }
// With bias and dilation.
1751  desc(prop_kind aprop_kind, algorithm aalgorithm,
1752  const memory::desc &src_desc,
1753  const memory::desc &weights_desc,
1754  const memory::desc &bias_desc,
1755  const memory::desc &dst_desc,
1756  const memory::dims strides,
1757  const memory::dims dilates,
1758  const memory::dims padding_l,
1759  const memory::dims padding_r,
1760  const padding_kind apadding_kind) {
1761  memory::validate_dims(strides);
1762  memory::validate_dims(dilates);
1763  memory::validate_dims(padding_l);
1764  memory::validate_dims(padding_r);
1766  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1767  &src_desc.data, &weights_desc.data, &bias_desc.data,
1768  &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1769  &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1770  "could not create a dilated deconvolution forward descriptor");
1771  }
// Without bias (nullptr), with dilation.
1772  desc(prop_kind aprop_kind, algorithm aalgorithm,
1773  const memory::desc &src_desc,
1774  const memory::desc &weights_desc,
1775  const memory::desc &dst_desc,
1776  const memory::dims strides,
1777  const memory::dims dilates,
1778  const memory::dims padding_l,
1779  const memory::dims padding_r,
1780  const padding_kind apadding_kind) {
1781  memory::validate_dims(strides);
1782  memory::validate_dims(dilates);
1783  memory::validate_dims(padding_l);
1784  memory::validate_dims(padding_r);
1786  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
1787  &src_desc.data, &weights_desc.data, nullptr,
1788  &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
1789  &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
1790  "could not create a dilated deconvolution forward descriptor");
1791  }
1792  };
1793 
// Primitive descriptor for forward deconvolution, optionally with a
// primitive_attr. No forward hint is passed (nullptr). Accessors: src_,
// weights_, bias_ and dst_primitive_desc(). Bias is queried as index 1 of
// the weights kind.
1795  primitive_desc(const desc &desc, const engine &e)
1796  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
1797 
1798  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
1799  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
1800 
1801  REG_QUERY_MPD(src, src, 0);
1802  REG_QUERY_MPD(weights, weights, 0);
1803  REG_QUERY_MPD(bias, weights, 1);
1804  REG_QUERY_MPD(dst, dst, 0);
1805  };
1806 
// Creates a forward deconvolution primitive with bias. Inputs are
// {src, weights, bias} and the single output is dst, so
// check_num_parameters is told 3 inputs and 1 output.
1807  deconvolution_forward(const primitive_desc &aprimitive_desc,
1808  const primitive::at &src, const primitive::at &weights,
1809  const primitive::at &bias, const memory &dst) {
1810  mkldnn_primitive_t result;
1811  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
1812  bias.data };
1813  const_mkldnn_primitive_t outputs[] = { dst.get() };
1814  check_num_parameters(aprimitive_desc.get(), 3, 1,
1815  "deconvolution forward");
1817  aprimitive_desc.get(), inputs, outputs),
1818  "could not create a deconvolution forward bias primitive");
1819  reset(result);
1820  }
1821 
// Creates a forward deconvolution primitive without bias. Inputs are
// {src, weights} and the single output is dst (2 inputs, 1 output).
1822  deconvolution_forward(const primitive_desc &aprimitive_desc,
1823  const primitive::at &src, const primitive::at &weights,
1824  const memory &dst) {
1825  mkldnn_primitive_t result;
1826  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
1827  const_mkldnn_primitive_t outputs[] = { dst.get() };
1828  check_num_parameters(aprimitive_desc.get(), 2, 1,
1829  "deconvolution forward");
1831  aprimitive_desc.get(), inputs, outputs),
1832  "could not create a deconvolution forward primitive");
1833  reset(result);
1834  }
1835 };
1836 
// Operation descriptor for deconvolution backward-data, built from the
// memory descriptors of diff_src, weights and diff_dst. Two overloads:
// without and with `dilates`. Every dims argument is checked with
// memory::validate_dims before a pointer to its first element is passed on.
1838  struct desc {
// Without dilation.
1840  desc(algorithm aalgorithm,
1841  const memory::desc &diff_src_desc,
1842  const memory::desc &weights_desc,
1843  const memory::desc &diff_dst_desc,
1844  const memory::dims strides,
1845  const memory::dims padding_l,
1846  const memory::dims padding_r,
1847  const padding_kind apadding_kind) {
1848  memory::validate_dims(strides);
1849  memory::validate_dims(padding_l);
1850  memory::validate_dims(padding_r);
1852  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1853  &weights_desc.data, &diff_dst_desc.data,
1854  &strides[0], &padding_l[0], &padding_r[0],
1855  mkldnn::convert_to_c(apadding_kind)),
1856  "could not create a deconvolution backward data descriptor");
1857  }
// With dilation.
1858  desc(algorithm aalgorithm,
1859  const memory::desc &diff_src_desc,
1860  const memory::desc &weights_desc,
1861  const memory::desc &diff_dst_desc,
1862  const memory::dims strides,
1863  const memory::dims dilates,
1864  const memory::dims padding_l,
1865  const memory::dims padding_r,
1866  const padding_kind apadding_kind) {
1867  memory::validate_dims(strides);
1868  memory::validate_dims(dilates);
1869  memory::validate_dims(padding_l);
1870  memory::validate_dims(padding_r);
1872  &data, convert_to_c(aalgorithm), &diff_src_desc.data,
1873  &weights_desc.data, &diff_dst_desc.data,
1874  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1875  mkldnn::convert_to_c(apadding_kind)),
1876  "could not create a dilated deconvolution backward data descriptor");
1877  }
1878  };
1879 
// Primitive descriptor for deconvolution backward-data. It requires the
// forward deconvolution primitive descriptor as a hint; the second overload
// also takes a primitive_attr. Accessors: diff_src_, weights_ and
// diff_dst_primitive_desc().
1881  primitive_desc(const desc &desc, const engine &e,
1882  const deconvolution_forward::primitive_desc &hint_fwd_pd)
1883  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1884 
1885  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
1886  const deconvolution_forward::primitive_desc &hint_fwd_pd)
1887  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
1888 
1889  REG_QUERY_MPD(diff_src, diff_src, 0);
1890  REG_QUERY_MPD(weights, weights, 0);
1891  REG_QUERY_MPD(diff_dst, diff_dst, 0);
1892  };
1893 
1895  const primitive::at &diff_dst, const primitive::at &weights,
1896  const memory &diff_src) {
// Deconvolution backward-data primitive. Inputs are ordered
// {diff_dst, weights} and the single output is diff_src
// (2 inputs, 1 output, as passed to check_num_parameters).
1897  mkldnn_primitive_t result;
1898  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
1899  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
1900  check_num_parameters(aprimitive_desc.get(), 2, 1,
1901  "deconvolution backward data");
1903  aprimitive_desc.get(), inputs, outputs),
1904  "could not create a deconvolution backward data primitive");
1905  reset(result);
1906  }
1907 };
1908 
// Operation descriptor for deconvolution backward-weights, built from the
// memory descriptors of src, diff_weights, an optional diff_bias, and
// diff_dst. Four overloads: with or without diff_bias, and with or without
// `dilates`. The no-bias overloads pass nullptr in the diff_bias position.
// Every dims argument is checked with memory::validate_dims before a
// pointer to its first element is passed on.
1910  struct desc {
// With diff_bias, without dilation.
1912  desc(algorithm aalgorithm,
1913  const memory::desc &src_desc,
1914  const memory::desc &diff_weights_desc,
1915  const memory::desc &diff_bias_desc,
1916  const memory::desc &diff_dst_desc,
1917  const memory::dims strides,
1918  const memory::dims padding_l,
1919  const memory::dims padding_r,
1920  const padding_kind apadding_kind) {
1921  memory::validate_dims(strides);
1922  memory::validate_dims(padding_l);
1923  memory::validate_dims(padding_r);
1925  &data, convert_to_c(aalgorithm), &src_desc.data,
1926  &diff_weights_desc.data, &diff_bias_desc.data,
1927  &diff_dst_desc.data,
1928  &strides[0], &padding_l[0], &padding_r[0],
1929  mkldnn::convert_to_c(apadding_kind)),
1930  "could not create a deconvolution backward weights descriptor");
1931  }
// Without diff_bias (nullptr), without dilation.
1932  desc(algorithm aalgorithm,
1933  const memory::desc &src_desc,
1934  const memory::desc &diff_weights_desc,
1935  const memory::desc &diff_dst_desc,
1936  const memory::dims strides,
1937  const memory::dims padding_l,
1938  const memory::dims padding_r,
1939  const padding_kind apadding_kind) {
1940  memory::validate_dims(strides);
1941  memory::validate_dims(padding_l);
1942  memory::validate_dims(padding_r);
1944  &data, convert_to_c(aalgorithm), &src_desc.data,
1945  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1946  &strides[0], &padding_l[0], &padding_r[0],
1947  mkldnn::convert_to_c(apadding_kind)),
1948  "could not create a deconvolution backward weights descriptor");
1949  }
// With diff_bias and dilation.
1950  desc(algorithm aalgorithm,
1951  const memory::desc &src_desc,
1952  const memory::desc &diff_weights_desc,
1953  const memory::desc &diff_bias_desc,
1954  const memory::desc &diff_dst_desc,
1955  const memory::dims strides,
1956  const memory::dims dilates,
1957  const memory::dims padding_l,
1958  const memory::dims padding_r,
1959  const padding_kind apadding_kind) {
1960  memory::validate_dims(strides);
1961  memory::validate_dims(dilates);
1962  memory::validate_dims(padding_l);
1963  memory::validate_dims(padding_r);
1965  &data, convert_to_c(aalgorithm), &src_desc.data,
1966  &diff_weights_desc.data, &diff_bias_desc.data,
1967  &diff_dst_desc.data,
1968  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1969  mkldnn::convert_to_c(apadding_kind)),
1970  "could not create a dilated deconvolution backward weights descriptor");
1971  }
// Without diff_bias (nullptr), with dilation.
1972  desc(algorithm aalgorithm,
1973  const memory::desc &src_desc,
1974  const memory::desc &diff_weights_desc,
1975  const memory::desc &diff_dst_desc,
1976  const memory::dims strides,
1977  const memory::dims dilates,
1978  const memory::dims padding_l,
1979  const memory::dims padding_r,
1980  const padding_kind apadding_kind) {
1981  memory::validate_dims(strides);
1982  memory::validate_dims(dilates);
1983  memory::validate_dims(padding_l);
1984  memory::validate_dims(padding_r);
1986  &data, convert_to_c(aalgorithm), &src_desc.data,
1987  &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
1988  &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
1989  mkldnn::convert_to_c(apadding_kind)),
1990  "could not create a dilated deconvolution backward weights descriptor");
1991  }
1992  };
1993 
// Primitive descriptor for deconvolution backward-weights. It requires the
// forward deconvolution primitive descriptor as a hint; the second overload
// also takes a primitive_attr. diff_bias is queried as index 1 of the
// diff_weights kind.
1995  primitive_desc(const desc &desc, const engine &e,
1996  const deconvolution_forward::primitive_desc &hint_fwd_pd)
1997  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
1998 
1999  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2000  const deconvolution_forward::primitive_desc &hint_fwd_pd)
2001  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2002 
2003  REG_QUERY_MPD(src, src, 0);
2004  REG_QUERY_MPD(diff_weights, diff_weights, 0);
2005  REG_QUERY_MPD(diff_bias, diff_weights, 1);
2006  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2007  };
2008 
2010  const primitive::at &src, const primitive::at &diff_dst,
2011  const memory &diff_weights, const memory &diff_bias) {
// Deconvolution backward-weights primitive that also produces diff_bias.
// Inputs are {src, diff_dst} and outputs are {diff_weights, diff_bias}
// (2 inputs, 2 outputs).
2012  mkldnn_primitive_t result;
2013  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2014  const_mkldnn_primitive_t outputs[] = { diff_weights.get(),
2015  diff_bias.get() };
2016  check_num_parameters(aprimitive_desc.get(), 2, 2,
2017  "deconvolution backward weights");
2019  aprimitive_desc.get(), inputs, outputs),
2020  "could not create a deconvolution backward weights primitive");
2021  reset(result);
2022  }
2024  const primitive::at &src, const primitive::at &diff_dst,
2025  const memory &diff_weights) {
// Deconvolution backward-weights primitive without diff_bias. Inputs are
// {src, diff_dst} and the single output is diff_weights
// (2 inputs, 1 output).
2026  mkldnn_primitive_t result;
2027  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2028  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
2029  check_num_parameters(aprimitive_desc.get(), 2, 1,
2030  "deconvolution backward weights");
2032  aprimitive_desc.get(), inputs, outputs),
2033  "could not create a deconvolution backward weights primitive");
2034  reset(result);
2035  }
2036 };
2037 
2039 
2046 
// Forward local response normalization (LRN) primitive wrapper.
2047 struct lrn_forward : public primitive {
// Operation descriptor. The overload without `k` passes float(1.0) for k.
2048  struct desc {
2050  desc(prop_kind aprop_kind, algorithm aalgorithm,
2051  const memory::desc &src_desc,
2052  int local_size, float alpha, float beta, float k)
2053  {
2055  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
2056  &src_desc.data, local_size, alpha, beta, k),
2057  "could not create a lrn forward descriptor");
2058  }
2059  desc(prop_kind aprop_kind, algorithm aalgorithm,
2060  const memory::desc &src_desc,
2061  int local_size, float alpha, float beta)
2062  {
2064  mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
2065  &src_desc.data, local_size, alpha, beta, float(1.0)),
2066  "could not create a lrn forward descriptor");
2067  }
2068  };
2069 
// Primitive descriptor (no forward hint). It exposes src_, dst_ and
// workspace_primitive_desc() accessors.
2071  primitive_desc(const desc &desc, const engine &e)
2072  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2073 
2074  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2075  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2076 
2077  REG_QUERY_MPD(src, src, 0);
2078  REG_QUERY_MPD(dst, dst, 0);
2079  REG_QUERY_MPD(workspace, workspace, 0);
2080  };
2081 
// Variant with a workspace output. The parameters are ordered
// (src, workspace, dst), but the output array is ordered {dst, workspace}
// (1 input, 2 outputs).
2082  lrn_forward(const primitive_desc &aprimitive_desc,
2083  const primitive::at &src, const memory &workspace,
2084  const memory &dst) {
2085  mkldnn_primitive_t result;
2086  mkldnn_primitive_at_t inputs[] = { src.data };
2087  const_mkldnn_primitive_t outputs[] = { dst.get(),
2088  workspace.get() };
2089  check_num_parameters(aprimitive_desc.get(), 1, 2, "lrn forward");
2091  aprimitive_desc.get(), inputs, outputs),
2092  "could not create a lrn forward primitive");
2093  reset(result);
2094  }
2095 
// Variant without a workspace: 1 input (src), 1 output (dst).
2096  lrn_forward(const primitive_desc &aprimitive_desc,
2097  const primitive::at &src, const memory &dst) {
2098  mkldnn_primitive_t result;
2099  mkldnn_primitive_at_t inputs[] = { src.data };
2100  const_mkldnn_primitive_t outputs[] = { dst.get() };
2101  check_num_parameters(aprimitive_desc.get(), 1, 1, "lrn forward");
2103  aprimitive_desc.get(), inputs, outputs),
2104  "could not create a lrn forward primitive");
2105  reset(result);
2106  }
2107 };
2108 
// Backward local response normalization (LRN) primitive wrapper.
2109 struct lrn_backward : public primitive {
// Operation descriptor. The C++ parameters are ordered
// (data_desc, diff_data_desc), but the underlying call receives
// diff_data_desc first and data_desc second. The overload without `k`
// passes float(1.0) for k.
2110  struct desc {
2112  desc(algorithm aalgorithm,
2113  const memory::desc &data_desc,
2114  const memory::desc &diff_data_desc,
2115  int local_size, float alpha, float beta, float k)
2116  {
2118  convert_to_c(aalgorithm), &diff_data_desc.data,
2119  &data_desc.data, local_size, alpha, beta, k),
2120  "could not create a lrn backward descriptor");
2121  }
2122  desc(algorithm aalgorithm,
2123  const memory::desc &data_desc,
2124  const memory::desc &diff_data_desc,
2125  int local_size, float alpha, float beta)
2126  {
2128  convert_to_c(aalgorithm), &diff_data_desc.data,
2129  &data_desc.data, local_size, alpha, beta, float(1.0)),
2130  "could not create a lrn backward descriptor");
2131  }
2132  };
2133 
// Primitive descriptor. It requires the lrn_forward primitive descriptor
// as a hint and exposes diff_src_, diff_dst_ and workspace_ accessors.
2135  primitive_desc(const desc &desc, const engine &e,
2136  const lrn_forward::primitive_desc &hint_fwd_pd)
2137  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2138 
2139  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2140  const lrn_forward::primitive_desc &hint_fwd_pd)
2141  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2142 
2143  REG_QUERY_MPD(diff_src, diff_src, 0);
2144  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2145  REG_QUERY_MPD(workspace, workspace, 0);
2146  };
2147 
// Variant with a workspace input: inputs {src, diff_dst, workspace},
// output diff_src (3 inputs, 1 output).
2148  lrn_backward(const primitive_desc &aprimitive_desc,
2149  const primitive::at &src, const primitive::at &diff_dst,
2150  const primitive::at &workspace, const memory &diff_src) {
2151  mkldnn_primitive_t result;
2152  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data,
2153  workspace.data };
2154  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2155  check_num_parameters(aprimitive_desc.get(), 3, 1, "lrn backward");
2157  aprimitive_desc.get(), inputs, outputs),
2158  "could not create a lrn backward primitive");
2159  reset(result);
2160  }
2161 
// Variant without a workspace: inputs {src, diff_dst}, output diff_src
// (2 inputs, 1 output).
2162  lrn_backward(const primitive_desc &aprimitive_desc,
2163  const primitive::at &src, const primitive::at &diff_dst,
2164  const memory &diff_src) {
2165  mkldnn_primitive_t result;
2166  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2167  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2168  check_num_parameters(aprimitive_desc.get(), 2, 1, "lrn backward");
2170  aprimitive_desc.get(), inputs, outputs),
2171  "could not create a lrn backward primitive");
2172  reset(result);
2173  }
2174 };
2175 
2177 
2183 
// Forward pooling primitive wrapper.
2184 struct pooling_forward : public primitive {
// Operation descriptor. strides, kernel, padding_l and padding_r are each
// checked with memory::validate_dims before pointers to their first
// elements are passed on. These dims are taken by value.
2185  struct desc {
2187  desc(prop_kind aprop_kind, algorithm aalgorithm,
2188  const memory::desc &src_desc,
2189  const memory::desc &dst_desc,
2190  const memory::dims strides,
2191  const memory::dims kernel,
2192  const memory::dims padding_l,
2193  const memory::dims padding_r,
2194  const padding_kind apadding_kind) {
2195  memory::validate_dims(strides);
2196  memory::validate_dims(kernel);
2197  memory::validate_dims(padding_l);
2198  memory::validate_dims(padding_r);
2200  mkldnn::convert_to_c(aprop_kind),
2201  convert_to_c(aalgorithm),
2202  &src_desc.data, &dst_desc.data,
2203  &strides[0], &kernel[0],
2204  &padding_l[0], &padding_r[0],
2205  mkldnn::convert_to_c(apadding_kind)),
2206  "could not init a forward pooling descriptor");
2207  }
2208  };
2209 
// Primitive descriptor (no forward hint). It exposes src_, dst_ and
// workspace_primitive_desc() accessors.
2211  primitive_desc(const desc &desc, const engine &e)
2212  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2213 
2214  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2215  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2216 
2217  REG_QUERY_MPD(src, src, 0);
2218  REG_QUERY_MPD(dst, dst, 0);
2219  REG_QUERY_MPD(workspace, workspace, 0);
2220  };
2221 
// Variant without a workspace. The outputs array has a trailing nullptr
// slot, but check_num_parameters is told 1 output.
2222  pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
2223  const memory &dst) {
2224  mkldnn_primitive_t result;
2225  mkldnn_primitive_at_t inputs[] = { src.data };
2226  const_mkldnn_primitive_t outputs[] = { dst.get(), nullptr };
2227  check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling forward");
2229  aprimitive_desc.get(), inputs, outputs),
2230  "could not create a pooling forward primitive");
2231  reset(result);
2232  }
2233 
// Variant with a workspace output: outputs {dst, workspace}
// (1 input, 2 outputs).
2234  pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src,
2235  const memory &dst, const memory &workspace) {
2236  mkldnn_primitive_t result;
2237  mkldnn_primitive_at_t inputs[] = { src.data };
2238  const_mkldnn_primitive_t outputs[] = { dst.get(), workspace.get() };
2239  check_num_parameters(aprimitive_desc.get(), 1, 2, "pooling forward");
2241  aprimitive_desc.get(), inputs, outputs),
2242  "could not create a pooling forward primitive");
2243  reset(result);
2244  }
2245 };
2246 
// Backward pooling primitive wrapper.
2247 struct pooling_backward : public primitive {
// Operation descriptor. Unlike pooling_forward::desc, the dims arguments
// here are taken by const reference. Each is checked with
// memory::validate_dims before pointers to its first element are passed on.
2248  struct desc {
2250  desc(algorithm aalgorithm,
2251  const memory::desc &diff_src_desc,
2252  const memory::desc &diff_dst_desc,
2253  const memory::dims &strides,
2254  const memory::dims &kernel,
2255  const memory::dims &padding_l,
2256  const memory::dims &padding_r,
2257  const padding_kind apadding_kind) {
2258  memory::validate_dims(strides);
2259  memory::validate_dims(kernel);
2260  memory::validate_dims(padding_l);
2261  memory::validate_dims(padding_r);
2263  convert_to_c(aalgorithm),
2264  &diff_src_desc.data, &diff_dst_desc.data,
2265  &strides[0], &kernel[0],
2266  &padding_l[0], &padding_r[0],
2267  mkldnn::convert_to_c(apadding_kind)),
2268  "could not init a backward pooling descriptor");
2269  }
2270  };
2271 
// Primitive descriptor. It requires the pooling_forward primitive
// descriptor as a hint and exposes diff_src_, diff_dst_ and workspace_
// accessors.
2273  primitive_desc(const desc &desc, const engine &e,
2274  const pooling_forward::primitive_desc &hint_fwd_pd)
2275  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2276 
2277  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2278  const pooling_forward::primitive_desc &hint_fwd_pd)
2279  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2280 
2281  REG_QUERY_MPD(diff_src, diff_src, 0);
2282  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2283  REG_QUERY_MPD(workspace, workspace, 0);
2284  };
2285 
// Variant without a workspace: input diff_dst, output diff_src
// (1 input, 1 output).
2286  pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
2287  const memory &diff_src) {
2288  mkldnn_primitive_t result;
2289  mkldnn_primitive_at_t inputs[] = { diff_dst.data };
2290  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2291  check_num_parameters(aprimitive_desc.get(), 1, 1, "pooling backward");
2293  aprimitive_desc.get(), inputs, outputs),
2294  "could not create a pooling backward primitive");
2295  reset(result);
2296  }
2297 
// Variant with a workspace input: inputs {diff_dst, workspace}, output
// diff_src (2 inputs, 1 output).
2298  pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst,
2299  const primitive::at &workspace, const memory &diff_src) {
2300  mkldnn_primitive_t result;
2301  mkldnn_primitive_at_t inputs[] = { diff_dst.data, workspace.data };
2302  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2303  check_num_parameters(aprimitive_desc.get(), 2, 1, "pooling backward");
2305  aprimitive_desc.get(), inputs, outputs),
2306  "could not create a pooling backward primitive");
2307  reset(result);
2308  }
2309 };
2310 
2312 
2319 
// eltwise_forward: C++ wrapper around the forward element-wise primitive
// (one input src, one output dst).
// NOTE(review): Doxygen listing. Lines 2322, 2326, 2334 and 2351 are absent
// (data member, desc-init call opener, primitive_desc header,
// primitive-create call opener).
2320 struct eltwise_forward : public primitive {
2321  struct desc {
// alpha and beta are static_cast to float before reaching the C API.
// Their meaning depends on the chosen algorithm, which is not visible here.
// NOTE(review): T is deduced only from explicitly passed arguments
// (default arguments do not take part in template deduction). A call
// that omits alpha therefore fails to compile unless T is given, e.g.
// pass 0.f explicitly.
2323  template <typename T>
2324  desc(prop_kind aprop_kind, algorithm alg_kind,
2325  const memory::desc &src_desc, T alpha = 0, T beta = 0) {
2327  mkldnn::convert_to_c(aprop_kind),
2328  mkldnn::convert_to_c(alg_kind), &src_desc.data,
2329  static_cast<float>(alpha), static_cast<float>(beta)),
2330  "could not create a eltwise forward descriptor");
2331  }
2332  };
2333 
// A forward primitive_desc takes no hint (nullptr is passed), optionally
// with attributes.
2335  primitive_desc(const desc &desc, const engine &e)
2336  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2337 
2338  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2339  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2340 
2341  REG_QUERY_MPD(src, src, 0);
2342  REG_QUERY_MPD(dst, dst, 0);
2343  };
2344 
// inputs = {src}, outputs = {dst}; arity checked as (1, 1).
2345  eltwise_forward(const primitive_desc &aprimitive_desc,
2346  const primitive::at &src, const memory &dst) {
2347  mkldnn_primitive_t result;
2348  mkldnn_primitive_at_t inputs[] = { src.data };
2349  const_mkldnn_primitive_t outputs[] = { dst.get() };
2350  check_num_parameters(aprimitive_desc.get(), 1, 1, "eltwise forward");
2352  aprimitive_desc.get(), inputs, outputs),
2353  "could not create a eltwise forward primitive");
2354  reset(result);
2355  }
2356 };
2357 
// eltwise_backward: C++ wrapper around the backward element-wise primitive.
// It computes diff_src from the forward input src and diff_dst.
// NOTE(review): Doxygen listing. Lines 2360, 2365, 2373 and 2394 are absent.
2358 struct eltwise_backward : public primitive {
2359  struct desc {
2361 
// Unlike the forward desc there is no prop_kind parameter. alpha and beta
// are cast to float. The same template-deduction caveat as
// eltwise_forward::desc applies: alpha must be passed explicitly.
2362  template <typename T>
2363  desc(algorithm alg_kind, const memory::desc &diff_data_desc,
2364  const memory::desc &data_desc, T alpha = 0, T beta = 0) {
2366  mkldnn::convert_to_c(alg_kind), &diff_data_desc.data,
2367  &data_desc.data, static_cast<float>(alpha),
2368  static_cast<float>(beta)),
2369  "could not create a eltwise backward descriptor");
2370  }
2371  };
2372 
// Requires the forward primitive_desc as a hint.
2374  primitive_desc(const desc &desc, const engine &e,
2375  const eltwise_forward::primitive_desc &hint_fwd_pd)
2376  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2377 
2378  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2379  const eltwise_forward::primitive_desc &hint_fwd_pd)
2380  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2381 
2382  REG_QUERY_MPD(src, src, 0);
2383  REG_QUERY_MPD(diff_src, diff_src, 0);
2384  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2385  };
2386 
// inputs = {src, diff_dst} (in that order), outputs = {diff_src};
// arity checked as (2, 1).
2387  eltwise_backward(const primitive_desc &aprimitive_desc,
2388  const primitive::at &src, const primitive::at &diff_dst,
2389  const memory &diff_src) {
2390  mkldnn_primitive_t result;
2391  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2392  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2393  check_num_parameters(aprimitive_desc.get(), 2, 1, "eltwise backward");
2395  aprimitive_desc.get(), inputs, outputs),
2396  "could not create a eltwise backward primitive");
2397  reset(result);
2398  }
2399 };
2400 
2402 
2408 
// softmax_forward: C++ wrapper around the forward softmax primitive,
// computed along softmax_axis of data_desc.
// NOTE(review): Doxygen listing. Lines 2411, 2414, 2421 and 2438 are absent.
2409 struct softmax_forward : public primitive {
2410  struct desc {
2412  desc(prop_kind aprop_kind, const memory::desc &data_desc,
2413  int softmax_axis) {
2415  mkldnn::convert_to_c(aprop_kind), &data_desc.data,
2416  softmax_axis),
2417  "could not create a softmax forward descriptor");
2418  }
2419  };
2420 
2422  primitive_desc(const desc &desc, const engine &e)
2423  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2424 
2425  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2426  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2427 
2428  REG_QUERY_MPD(src, src, 0);
2429  REG_QUERY_MPD(dst, dst, 0);
2430  };
2431 
// inputs = {src}, outputs = {dst}; arity checked as (1, 1).
2432  softmax_forward(const primitive_desc &aprimitive_desc,
2433  const primitive::at &src, const memory &dst) {
2434  mkldnn_primitive_t result;
2435  mkldnn_primitive_at_t inputs[] = { src.data };
2436  const_mkldnn_primitive_t outputs[] = { dst.get() };
2437  check_num_parameters(aprimitive_desc.get(), 1, 1, "softmax forward");
2439  aprimitive_desc.get(), inputs, outputs),
2440  "could not create a softmax forward primitive");
2441  reset(result);
2442  }
2443 };
2444 
2445 struct softmax_backward : public primitive {
2446  struct desc {
2448  desc(const memory::desc &diff_desc, const memory::desc &data_desc,
2449  int softmax_axis) {
2451  &diff_desc.data, &data_desc.data, softmax_axis),
2452  "could not init a backward softmax descriptor");
2453  }
2454  };
2455 
2457  primitive_desc(const desc &desc, const engine &e,
2458  const softmax_forward::primitive_desc &hint_fwd_pd)
2459  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2460 
2461  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2462  const softmax_forward::primitive_desc &hint_fwd_pd)
2463  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2464 
2465  REG_QUERY_MPD(dst, dst, 0);
2466  REG_QUERY_MPD(diff_src, diff_src, 0);
2467  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2468  REG_QUERY_MPD(workspace, workspace, 0);
2469  };
2470 
2471  softmax_backward(const primitive_desc &aprimitive_desc,
2472  const primitive::at &dst, const primitive::at &diff_dst,
2473  const memory &diff_src) {
2474  mkldnn_primitive_t result;
2475  mkldnn_primitive_at_t inputs[] = { dst.data, diff_dst.data };
2476  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2478  aprimitive_desc.get(), inputs, outputs),
2479  "could not create a softmax backward primitive");
2480  reset(result);
2481  }
2482 };
2483 
2485 
2491 
// Body of batch_normalization_forward (the `struct ... : public primitive {`
// header on source line 2492 is absent from this Doxygen listing).
// NOTE(review): many further lines are absent (2494, 2498-2499, 2506,
// 2518, 2520, 2526-2528, the constructor name lines 2534/2550/2565-2572/
// 2587/2603/2618-2628/2659/2673, and each primitive-create call opener).
2493  struct desc {
// epsilon is static_cast to float; flags is passed through as unsigned.
2495  template <typename T>
2496  desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon,
2497  unsigned flags) {
2500  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2501  static_cast<float>(epsilon), flags),
2502  "could not create a batch normalization forward descriptor");
2503  }
2504  };
2505 
2507  primitive_desc(const desc &desc, const engine &e)
2508  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2509 
2510  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2511  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2512 
2513  REG_QUERY_MPD(src, src, 0);
2514  REG_QUERY_MPD(weights, weights, 0);
2515  REG_QUERY_MPD(dst, dst, 0);
2516  REG_QUERY_MPD(workspace, workspace, 0);
2517 
// Mean / variance accessors (their signature lines are absent here).
// Both delegate to stat_primitive_desc with index mean (1) or var (2).
2519  { return stat_primitive_desc(mean); }
2521  { return stat_primitive_desc(var); }
2522 
2523  private:
2524  enum { mean = 1, var = 2, };
// Looks up the memory primitive_desc of a statistic. When the descriptor's
// flags contain use_global_stats, the statistics are looked up among the
// src-kind memories (they are supplied by the caller); otherwise they are
// looked up among the dst-kind memories (they are produced).
2525  memory::primitive_desc stat_primitive_desc(int kind) const {
2529  "could not get a batch-normalization descriptor");
2530  return query_mpd(p->flags & use_global_stats ? src_pd : dst_pd, kind);
2531  }
2532  };
2533 
// Global statistics + scale/shift: inputs = {src, mean, variance, weights},
// outputs = {dst}; arity (4, 1).
2535  const primitive::at &src, const primitive::at &mean,
2536  const primitive::at &variance, const primitive::at &weights,
2537  const memory &dst) {
2538  mkldnn_primitive_t result;
2539  mkldnn_primitive_at_t inputs[] = { src.data,
2540  mean.data, variance.data, weights.data };
2541  const_mkldnn_primitive_t outputs[] = { dst.get() };
2542  check_num_parameters(aprimitive_desc.get(), 4, 1,
2543  "batch normalization forward");
2545  aprimitive_desc.get(), inputs, outputs),
2546  "could not create a batch normalization forward primitive");
2547  reset(result);
2548  }
2549 
// Global statistics without weights: inputs = {src, mean, variance},
// outputs = {dst}; arity (3, 1).
2551  const primitive::at &src, const primitive::at &mean,
2552  const primitive::at &variance, const memory &dst) {
2553  mkldnn_primitive_t result;
2554  mkldnn_primitive_at_t inputs[] = { src.data,
2555  mean.data, variance.data };
2556  const_mkldnn_primitive_t outputs[] = { dst.get() };
2557  check_num_parameters(aprimitive_desc.get(), 3, 1,
2558  "batch normalization forward");
2560  aprimitive_desc.get(), inputs, outputs),
2561  "could not create a batch normalization forward primitive");
2562  reset(result);
2563  }
2564 
// Computed statistics + weights: inputs = {src, weights},
// outputs = {dst, mean, variance}; arity (2, 3).
2573  const primitive::at &src, const primitive::at &weights,
2574  const memory &dst, const memory &mean, const memory &variance) {
2575  mkldnn_primitive_t result;
2576  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2577  const_mkldnn_primitive_t outputs[] = { dst.get(),
2578  mean.get(), variance.get() };
2579  check_num_parameters(aprimitive_desc.get(), 2, 3,
2580  "batch normalization forward");
2582  aprimitive_desc.get(), inputs, outputs),
2583  "could not create a batch normalization forward primitive");
2584  reset(result);
2585  }
2586 
// Same as above plus a workspace output: outputs =
// {dst, mean, variance, workspace}; arity (2, 4).
2588  const primitive::at &src, const primitive::at &weights,
2589  const memory &dst, const memory &mean, const memory &variance,
2590  const memory &workspace) {
2591  mkldnn_primitive_t result;
2592  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2593  const_mkldnn_primitive_t outputs[] = { dst.get(),
2594  mean.get(), variance.get(), workspace.get() };
2595  check_num_parameters(aprimitive_desc.get(), 2, 4,
2596  "batch normalization forward");
2598  aprimitive_desc.get(), inputs, outputs),
2599  "could not create a batch normalization forward primitive");
2600  reset(result);
2601  }
2602 
// Computed statistics, no weights: inputs = {src},
// outputs = {dst, mean, variance}; arity (1, 3).
2604  const primitive::at &src, const memory &dst, const memory &mean,
2605  const memory &variance) {
2606  mkldnn_primitive_t result;
2607  mkldnn_primitive_at_t inputs[] = { src.data };
2608  const_mkldnn_primitive_t outputs[] = { dst.get(),
2609  mean.get(), variance.get() };
2610  check_num_parameters(aprimitive_desc.get(), 1, 3,
2611  "batch normalization forward");
2613  aprimitive_desc.get(), inputs, outputs),
2614  "could not create a batch normalization forward primitive");
2615  reset(result);
2616  }
2617 
// Computed statistics + workspace, no weights: inputs = {src},
// outputs = {dst, mean, variance, workspace}.
// The argument list (src + 4 memories) is ambiguous with "src, weights,
// dst, mean, variance". If the descriptor reports 2 inputs and 3 outputs,
// the arguments are reinterpreted that way: `dst` becomes the weights
// input, and mean/variance/workspace shift down into dst/mean/variance;
// outputs[3] is nulled.
// inputs[2] has only one initializer, so inputs[1] is zero-initialized
// unless the shift branch overwrites it.
// NOTE(review): unlike the sibling constructors, this one never calls
// check_num_parameters. `if (1)` is always true and only scopes the check.
2629  const primitive::at &src, const memory &dst, const memory &mean,
2630  const memory &variance, const memory &workspace) {
2631  mkldnn_primitive_t result;
2632  mkldnn_primitive_at_t inputs[2] = { src.data };
2633  const_mkldnn_primitive_t outputs[4] = { dst.get(),
2634  mean.get(), variance.get(), workspace.get() };
2635 
2636  if (1) { // check whether this is the `wrong` constructor
2637  const int n_inputs_expected = mkldnn_primitive_desc_query_s32(
2638  aprimitive_desc.get(), mkldnn_query_num_of_inputs_s32, 0);
2639  const int n_outputs_expected = mkldnn_primitive_desc_query_s32(
2640  aprimitive_desc.get(), mkldnn_query_num_of_outputs_s32, 0);
2641  if (n_inputs_expected == 2 && n_outputs_expected == 3) {
2642  // shift parameters, get rid of workspace, and add weights...
2643  auto _weights = dst;
2644  inputs[1] = {_weights.get(), 0};
2645 
2646  auto _dst = mean, _mean = variance, _variance = workspace;
2647  outputs[0] = _dst.get();
2648  outputs[1] = _mean.get();
2649  outputs[2] = _variance.get();
2650  outputs[3] = nullptr;
2651  }
2652  }
2654  aprimitive_desc.get(), inputs, outputs),
2655  "could not create a batch normalization forward primitive");
2656  reset(result);
2657  }
2658 
// Inference-style with weights: inputs = {src, weights}, outputs = {dst};
// arity (2, 1).
2660  const primitive::at &src, const primitive::at &weights,
2661  const memory &dst) {
2662  mkldnn_primitive_t result;
2663  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2664  const_mkldnn_primitive_t outputs[] = { dst.get() };
2665  check_num_parameters(aprimitive_desc.get(), 2, 1,
2666  "batch normalization forward");
2668  aprimitive_desc.get(), inputs, outputs),
2669  "could not create a batch normalization forward primitive");
2670  reset(result);
2671  }
2672 
// Minimal form: inputs = {src}, outputs = {dst}; arity (1, 1).
2674  const primitive::at &src, const memory &dst) {
2675  mkldnn_primitive_t result;
2676  mkldnn_primitive_at_t inputs[] = { src.data };
2677  const_mkldnn_primitive_t outputs[] = { dst.get() };
2678  check_num_parameters(aprimitive_desc.get(), 1, 1,
2679  "batch normalization forward");
2681  aprimitive_desc.get(), inputs, outputs),
2682  "could not create a batch normalization forward primitive");
2683  reset(result);
2684  }
2685 };
2686 
// Body of batch_normalization_backward (the struct header on source line
// 2687 is absent from this Doxygen listing, as are 2689, 2693-2694, 2702,
// 2704, 2708, the constructor name lines and the create-call openers).
2688  struct desc {
// epsilon is static_cast to float; flags is passed through unchanged.
2690  template <typename T>
2691  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
2692  const memory::desc &data_desc, T epsilon, unsigned flags) {
2695  mkldnn::convert_to_c(aprop_kind),
2696  &diff_data_desc.data, &data_desc.data,
2697  static_cast<float>(epsilon), flags),
2698  "could not create a batch normalization backward descriptor");
2699  }
2700  };
2701 
// Both constructors take a forward primitive_desc hint (its declaration
// line is absent here) and pass hint_fwd_pd.get() through.
2703  primitive_desc(const desc &desc, const engine &e,
2705  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2706 
2707  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2709  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2710 
// mean and variance are exposed as src-kind memories at indices 1 and 2.
2711  REG_QUERY_MPD(src, src, 0);
2712  REG_QUERY_MPD(mean, src, 1);
2713  REG_QUERY_MPD(variance, src, 2);
2714  REG_QUERY_MPD(weights, weights, 0);
2715  REG_QUERY_MPD(dst, dst, 0);
2716  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2717  REG_QUERY_MPD(workspace, workspace, 0);
2718 
2719  REG_QUERY_MPD(diff_src, diff_src, 0);
2720  REG_QUERY_MPD(diff_weights, diff_weights, 0);
2721  };
2722 
// inputs = {src, mean, variance, diff_dst, weights},
// outputs = {diff_src, diff_weights}; arity (5, 2).
2723  // Prop_kind == backward
2725  const primitive::at &src, const primitive::at &mean,
2726  const primitive::at &variance, const primitive::at &diff_dst,
2727  const primitive::at &weights, const memory &diff_src,
2728  const memory &diff_weights) {
2729  mkldnn_primitive_t result;
2730  mkldnn_primitive_at_t inputs[] = { src.data,
2731  mean.data, variance.data, diff_dst.data, weights.data };
2732  const_mkldnn_primitive_t outputs[] = { diff_src.get(),
2733  diff_weights.get() };
2734  check_num_parameters(aprimitive_desc.get(), 5, 2,
2735  "batch normalization backward");
2737  aprimitive_desc.get(), inputs, outputs),
2738  "could not create a batch normalization backward primitive");
2739  reset(result);
2740  }
2741 
// As above plus a trailing workspace input; arity (6, 2).
2742  // Prop_kind == backward (+ws)
2744  const primitive::at &src, const primitive::at &mean,
2745  const primitive::at &variance, const primitive::at &diff_dst,
2746  const primitive::at &weights, const primitive::at &workspace,
2747  const memory &diff_src, const memory &diff_weights) {
2748  mkldnn_primitive_t result;
2749  mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
2750  diff_dst.data, weights.data, workspace.data };
2751  const_mkldnn_primitive_t outputs[] = { diff_src.get(),
2752  diff_weights.get() };
2753  check_num_parameters(aprimitive_desc.get(), 6, 2,
2754  "batch normalization backward");
2756  aprimitive_desc.get(), inputs, outputs),
2757  "could not create a batch normalization backward primitive");
2758  reset(result);
2759  }
2760 
// The fifth input is either weights or workspace. The wrapper passes it
// through unchanged; which one is expected is decided by the descriptor.
// arity (5, 1).
2761  // Prop_kind == backward_data (+ws or +weights)
2766  const primitive::at &src, const primitive::at &mean,
2767  const primitive::at &variance,const primitive::at &diff_dst,
2768  const primitive::at &weights_or_workspace, const memory &diff_src) {
2769  mkldnn_primitive_t result;
2770  mkldnn_primitive_at_t inputs[] = { src.data, mean.data, variance.data,
2771  diff_dst.data, weights_or_workspace.data };
2772  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2773  check_num_parameters(aprimitive_desc.get(), 5, 1,
2774  "batch normalization backward");
2776  aprimitive_desc.get(), inputs, outputs),
2777  "could not create a batch normalization backward primitive");
2778  reset(result);
2779  }
2780 
// inputs = {src, mean, variance, diff_dst}, outputs = {diff_src};
// arity (4, 1).
2781  // Prop_kind == backward_data
2783  const primitive::at &src, const primitive::at &mean,
2784  const primitive::at &variance, const primitive::at &diff_dst,
2785  const memory &diff_src) {
2786  mkldnn_primitive_t result;
2787  mkldnn_primitive_at_t inputs[] = { src.data,
2788  mean.data, variance.data, diff_dst.data };
2789  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2790  check_num_parameters(aprimitive_desc.get(), 4, 1,
2791  "batch normalization backward");
2793  aprimitive_desc.get(), inputs, outputs),
2794  "could not create a batch normalization backward primitive");
2795  reset(result);
2796  }
2797 };
2798 
2800 
2806 
// Body of inner_product_forward (the struct header on source line 2807 is
// absent from this Doxygen listing, as are the data member, the
// desc-init openers, the primitive_desc header and the create-call openers).
2808  struct desc {
// With bias: bias_desc is forwarded to the C initializer.
2810  desc(prop_kind aprop_kind, const memory::desc &src_desc,
2811  const memory::desc &weights_desc,
2812  const memory::desc &bias_desc,
2813  const memory::desc &dst_desc) {
2816  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2817  &weights_desc.data, &bias_desc.data, &dst_desc.data),
2818  "could not create a inner product forward descriptor");
2819  }
2820 
// Without bias: nullptr is passed in the bias slot.
2821  desc(prop_kind aprop_kind, const memory::desc &src_desc,
2822  const memory::desc &weights_desc,
2823  const memory::desc &dst_desc) {
2826  mkldnn::convert_to_c(aprop_kind), &src_desc.data,
2827  &weights_desc.data, nullptr, &dst_desc.data),
2828  "could not create a inner product forward descriptor");
2829  }
2830  };
2831 
2833  primitive_desc(const desc &desc, const engine &e)
2834  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
2835 
2836  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
2837  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
2838 
// bias is exposed as weights-kind memory at index 1.
2839  REG_QUERY_MPD(src, src, 0);
2840  REG_QUERY_MPD(weights, weights, 0);
2841  REG_QUERY_MPD(bias, weights, 1);
2842  REG_QUERY_MPD(dst, dst, 0);
2843  };
2844 
// inputs = {src, weights, bias}, outputs = {dst}; arity (3, 1).
// NOTE(review): `weights` is taken by value (const primitive::at), unlike
// the const& used for the other inputs. The behavior is the same, but the
// style is inconsistent.
2845  inner_product_forward(const primitive_desc &aprimitive_desc,
2846  const primitive::at &src, const primitive::at weights,
2847  const primitive::at &bias, const memory &dst) {
2848  mkldnn_primitive_t result;
2849  mkldnn_primitive_at_t inputs[] = { src.data, weights.data,
2850  bias.data };
2851  const_mkldnn_primitive_t outputs[] = { dst.get() };
2852  check_num_parameters(aprimitive_desc.get(), 3, 1,
2853  "inner product forward");
2855  aprimitive_desc.get(), inputs, outputs),
2856  "could not create a inner product forward primitive");
2857  reset(result);
2858  }
2859 
// No-bias form: inputs = {src, weights}, outputs = {dst}; arity (2, 1).
2860  inner_product_forward(const primitive_desc &aprimitive_desc,
2861  const primitive::at &src, const primitive::at weights,
2862  const memory &dst) {
2863  mkldnn_primitive_t result;
2864  mkldnn_primitive_at_t inputs[] = { src.data, weights.data };
2865  const_mkldnn_primitive_t outputs[] = { dst.get() };
2866  check_num_parameters(aprimitive_desc.get(), 2, 1,
2867  "inner product forward");
2869  aprimitive_desc.get(), inputs, outputs),
2870  "could not create a inner product forward primitive");
2871  reset(result);
2872  }
2873 };
2874 
// Body of inner_product_backward_data (the struct header on source line
// 2875 is absent from this Doxygen listing, along with 2877, 2881-2882,
// 2889, 2903 and 2911).
// It computes diff_src from diff_dst and weights.
2876  struct desc {
2878  desc(const memory::desc &diff_src_desc,
2879  const memory::desc &weights_desc,
2880  const memory::desc &diff_dst_desc) {
2883  &diff_src_desc.data, &weights_desc.data,
2884  &diff_dst_desc.data),
2885  "could not create a inner product backward data descriptor");
2886  }
2887  };
2888 
// Requires the forward inner-product primitive_desc as a hint.
2890  primitive_desc(const desc &desc, const engine &e,
2891  const inner_product_forward::primitive_desc &hint_fwd_pd)
2892  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2893 
2894  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2895  const inner_product_forward::primitive_desc &hint_fwd_pd)
2896  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2897 
2898  REG_QUERY_MPD(diff_src, diff_src, 0);
2899  REG_QUERY_MPD(weights, weights, 0);
2900  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2901  };
2902 
// inputs = {diff_dst, weights}, outputs = {diff_src}; arity (2, 1).
// NOTE(review): `weights` is taken by value, unlike the other inputs.
2904  const primitive::at &diff_dst, const primitive::at weights,
2905  const memory &diff_src) {
2906  mkldnn_primitive_t result;
2907  mkldnn_primitive_at_t inputs[] = { diff_dst.data, weights.data };
2908  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
2909  check_num_parameters(aprimitive_desc.get(), 2, 1,
2910  "inner product backward data");
2912  aprimitive_desc.get(), inputs, outputs),
2913  "could not create a inner product backward data primitive");
2914  reset(result);
2915  }
2916 };
2917 
// Body of inner_product_backward_weights (the struct header on source line
// 2918 is absent from this Doxygen listing, as are 2920, 2925-2926,
// 2934-2935, 2942, 2957, 2965, 2971 and 2980).
// It computes diff_weights (and optionally diff_bias) from src and diff_dst.
2919  struct desc {
// With diff_bias.
2921  desc(const memory::desc &src_desc,
2922  const memory::desc &diff_weights_desc,
2923  const memory::desc &diff_bias_desc,
2924  const memory::desc &diff_dst_desc) {
2927  &data, &src_desc.data, &diff_weights_desc.data,
2928  &diff_bias_desc.data, &diff_dst_desc.data),
2929  "could not create a inner product backward weights descriptor");
2930  }
// Without diff_bias: nullptr is passed in the bias slot.
2931  desc(const memory::desc &src_desc,
2932  const memory::desc &diff_weights_desc,
2933  const memory::desc &diff_dst_desc) {
2936  &data, &src_desc.data, &diff_weights_desc.data,
2937  nullptr, &diff_dst_desc.data),
2938  "could not create a inner product backward weights descriptor");
2939  }
2940  };
2941 
2943  primitive_desc(const desc &desc, const engine &e,
2944  const inner_product_forward::primitive_desc &hint_fwd_pd)
2945  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
2946 
2947  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
2948  const inner_product_forward::primitive_desc &hint_fwd_pd)
2949  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
2950 
// diff_bias is exposed as diff_weights-kind memory at index 1.
2951  REG_QUERY_MPD(src, src, 0);
2952  REG_QUERY_MPD(diff_weights, diff_weights, 0);
2953  REG_QUERY_MPD(diff_bias, diff_weights, 1);
2954  REG_QUERY_MPD(diff_dst, diff_dst, 0);
2955  };
2956 
// inputs = {src, diff_dst}, outputs = {diff_weights}; arity (2, 1).
// NOTE(review): diff_dst is taken by value, unlike src.
2958  const primitive::at &src, const primitive::at diff_dst,
2959  const memory &diff_weights) {
2960  mkldnn_primitive_t result;
2961  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2962  const_mkldnn_primitive_t outputs[] = { diff_weights.get() };
2963  check_num_parameters(aprimitive_desc.get(), 2, 1,
2964  "inner product backward weights");
2966  aprimitive_desc.get(), inputs, outputs),
2967  "could not create a inner product backward weights primitive");
2968  reset(result);
2969  }
2970 
// inputs = {src, diff_dst}, outputs = {diff_weights, diff_bias};
// arity (2, 2).
2972  const primitive::at &src, const primitive::at diff_dst,
2973  const memory &diff_weights, const memory &diff_bias) {
2974  mkldnn_primitive_t result;
2975  mkldnn_primitive_at_t inputs[] = { src.data, diff_dst.data };
2976  const_mkldnn_primitive_t outputs[] =
2977  { diff_weights.get(), diff_bias.get()};
2978  check_num_parameters(aprimitive_desc.get(), 2, 2,
2979  "inner product backward weights");
2981  aprimitive_desc.get(), inputs, outputs),
2982  "could not create a inner product backward weights primitive");
2983  reset(result);
2984  }
2985 };
2986 
2988 
2994 
// rnn_cell: a namespace-like holder for rnn_cell::desc, a thin wrapper over
// the C RNN cell descriptor (member c_rnn_cell_).
// NOTE(review): Doxygen listing. Lines 2997 (member declaration), 3000
// (C init call opener), 3005, 3009 and 3011 (accessor signatures) are
// absent.
2995 struct rnn_cell {
2996  struct desc {
2998 
// Initializes the cell with the given cell kind and activation.
// 0U, 0 and 0 are passed for the remaining C initializer parameters;
// presumably flags, alpha and clipping, given the setters below. TODO confirm.
2999  desc(algorithm kind, algorithm activation_f) {
3001  mkldnn::convert_to_c(kind),
3002  mkldnn::convert_to_c(activation_f), 0U, 0, 0),
3003  "could not init an rnn cell descriptor");
3004  }
3006 
// Implicit conversion to the C descriptor pointer. rnn_forward::desc and
// rnn_backward::desc rely on it to pass `cell` straight to their C
// initializers.
3007  operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; }
3008 
3010  { return algorithm(c_rnn_cell_.cell_kind); }
3012  { return algorithm(c_rnn_cell_.activation_kind); }
3013 
// Setting alpha also ORs mkldnn_rnn_cell_with_relu into the flags. The
// flag is never cleared by this wrapper.
3014  float get_alpha() const { return c_rnn_cell_.alpha; }
3015  void set_alpha(float alpha) {
3016  c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu;
3017  c_rnn_cell_.alpha = alpha;
3018  }
3019 
// Setting clipping also ORs mkldnn_rnn_cell_with_clipping into the flags.
3020  float get_clipping() const { return c_rnn_cell_.clipping; }
3021  void set_clipping(float clipping) {
3022  c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping;
3023  c_rnn_cell_.clipping = clipping;
3024  }
3025 
// Gate/state counts are delegated to the C API for the current cell.
3026  int get_gates_count() const {
3027  return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_);
3028  }
3029  int get_state_count() const {
3030  return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_);
3031  }
3032  };
3033 };
3034 
// rnn_forward: C++ wrapper around the forward RNN primitive.
// NOTE(review): Doxygen listing. Lines 3037 (data member), 3048 (desc-init
// opener), 3060 (primitive_desc header) and 3099 (create-call opener)
// are absent.
3035 struct rnn_forward : public primitive {
3036  struct desc {
// `cell` is taken by value and converted implicitly to
// const mkldnn_rnn_cell_desc_t* when passed to the C initializer.
3038  desc(prop_kind aprop_kind, rnn_cell::desc cell,
3039  const rnn_direction direction,
3040  const memory::desc &src_layer_desc,
3041  const memory::desc &src_iter_desc,
3042  const memory::desc &weights_layer_desc,
3043  const memory::desc &weights_iter_desc,
3044  const memory::desc &bias_desc,
3045  const memory::desc &dst_layer_desc,
3046  const memory::desc &dst_iter_desc
3047  ) {
3049  mkldnn::convert_to_c(aprop_kind), cell,
3050  mkldnn::convert_to_c(direction),
3051  &src_layer_desc.data, &src_iter_desc.data,
3052  &weights_layer_desc.data, &weights_iter_desc.data,
3053  &bias_desc.data,
3054  &dst_layer_desc.data, &dst_iter_desc.data),
3055  "could not create an RNN forward descriptor");
3056  }
3057 
3058  };
3059 
3061  primitive_desc(const desc &desc, const engine &e)
3062  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3063 
3064  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
3065  : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
3066 
// Memory index layout: src = {layer, iter}, weights = {layer, iter, bias},
// dst = {layer, iter}.
3067  REG_QUERY_MPD(src_layer, src, 0);
3068  REG_QUERY_MPD(src_iter, src, 1);
3069  REG_QUERY_MPD(weights_layer, weights, 0);
3070  REG_QUERY_MPD(weights_iter, weights, 1);
3071  REG_QUERY_MPD(bias, weights, 2);
3072  REG_QUERY_MPD(dst_layer, dst, 0);
3073  REG_QUERY_MPD(dst_iter, dst, 1);
3074  REG_QUERY_MPD(workspace, workspace, 0);
3075  };
3076 
// Optional arguments (src_iter, bias, dst_iter, workspace) are dropped from
// the packed arrays when is_null_memory() reports them as null. The
// remaining entries are packed contiguously via idx, so positions shift.
// Array capacities (5 inputs, 3 outputs) equal the maximum possible counts.
// NOTE(review): no check_num_parameters call here, unlike the fixed-arity
// wrappers; arity mismatches are left to mkldnn_primitive_create.
3077  rnn_forward(const primitive_desc &aprimitive_desc,
3078  const primitive::at &src_layer, const primitive::at &src_iter,
3079  const primitive::at &weights_layer,
3080  const primitive::at &weights_iter, const primitive::at &bias,
3081  const memory &dst_layer, const memory &dst_iter,
3082  const memory &workspace) {
3083  mkldnn_primitive_t result;
3084  mkldnn_primitive_at_t inputs[5];
3085  const_mkldnn_primitive_t outputs[3];
3086  int idx=0;
3087  inputs[idx++] = src_layer.data;
3088  if (!is_null_memory(src_iter.data.primitive))
3089  inputs[idx++] = src_iter.data;
3090  inputs[idx++] = weights_layer.data;
3091  inputs[idx++] = weights_iter.data;
3092  if (!is_null_memory(bias.data.primitive)) inputs[idx++] = bias.data;
3093 
3094  idx=0;
3095  outputs[idx++] = dst_layer.get();
3096  if (!is_null_memory(dst_iter.get())) outputs[idx++] = dst_iter.get();
3097  if (!is_null_memory(workspace.get())) outputs[idx++] = workspace.get();
3098 
3100  aprimitive_desc.get(), inputs, outputs),
3101  "could not create an RNN forward primitive");
3102  reset(result);
3103  }
3104 };
3105 
// rnn_backward: C++ wrapper around the backward RNN primitive.
// NOTE(review): Doxygen listing. Lines 3108, 3125, 3141 and 3211 (data
// member, desc-init opener, primitive_desc header, create-call opener)
// are absent.
3106 struct rnn_backward : public primitive {
3107  struct desc {
// Takes every forward descriptor plus the matching diff_* descriptors.
// `cell` converts implicitly to the C cell-descriptor pointer.
3109  desc(prop_kind aprop_kind, rnn_cell::desc cell,
3110  const rnn_direction direction,
3111  const memory::desc &src_layer_desc,
3112  const memory::desc &src_iter_desc,
3113  const memory::desc &weights_layer_desc,
3114  const memory::desc &weights_iter_desc,
3115  const memory::desc &bias_desc,
3116  const memory::desc &dst_layer_desc,
3117  const memory::desc &dst_iter_desc,
3118  const memory::desc &diff_src_layer_desc,
3119  const memory::desc &diff_src_iter_desc,
3120  const memory::desc &diff_weights_layer_desc,
3121  const memory::desc &diff_weights_iter_desc,
3122  const memory::desc &diff_bias_desc,
3123  const memory::desc &diff_dst_layer_desc,
3124  const memory::desc &diff_dst_iter_desc) {
3126  mkldnn::convert_to_c(aprop_kind), cell,
3127  mkldnn::convert_to_c(direction),
3128  &src_layer_desc.data, &src_iter_desc.data,
3129  &weights_layer_desc.data, &weights_iter_desc.data,
3130  &bias_desc.data,
3131  &dst_layer_desc.data, &dst_iter_desc.data,
3132  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
3133  &diff_weights_layer_desc.data,
3134  &diff_weights_iter_desc.data, &diff_bias_desc.data,
3135  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data),
3136  "could not create an RNN backward descriptor");
3137  }
3138 
3139  };
3140 
// Requires the forward RNN primitive_desc as a hint.
3142  primitive_desc(const desc &desc, const engine &e,
3143  const rnn_forward::primitive_desc &hint_fwd_pd)
3144  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3145 
3146  primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
3147  const rnn_forward::primitive_desc &hint_fwd_pd)
3148  : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
3149 
3150  REG_QUERY_MPD(src_layer, src, 0);
3151  REG_QUERY_MPD(src_iter, src, 1);
3152  REG_QUERY_MPD(weights_layer, weights, 0);
3153  REG_QUERY_MPD(weights_iter, weights, 1);
3154  REG_QUERY_MPD(bias, weights, 2);
3155  REG_QUERY_MPD(dst_layer, dst, 0);
3156  REG_QUERY_MPD(dst_iter, dst, 1);
3157  REG_QUERY_MPD(workspace, workspace, 0);
3158 
3159  REG_QUERY_MPD(diff_src_layer, diff_src, 0);
3160  REG_QUERY_MPD(diff_src_iter, diff_src, 1);
3161  REG_QUERY_MPD(diff_weights_layer, diff_weights, 0);
3162  REG_QUERY_MPD(diff_weights_iter, diff_weights, 1);
3163  REG_QUERY_MPD(diff_bias, diff_weights, 2);
3164  REG_QUERY_MPD(diff_dst_layer, diff_dst, 0);
3165  REG_QUERY_MPD(diff_dst_iter, diff_dst, 1);
3166  };
3167 
// Input packing order: src_layer, [src_iter], weights_layer,
// weights_iter, [bias], dst_layer, [dst_iter], diff_dst_layer,
// [diff_dst_iter], workspace. Bracketed entries are skipped when
// is_null_memory() reports them null. workspace is always passed.
// Output packing: diff_src_layer, [diff_src_iter], diff_weights_layer,
// diff_weights_iter, [diff_bias].
// Capacities (10 inputs, 5 outputs) equal the maximum possible counts.
// NOTE(review): no check_num_parameters call (variable arity).
3168  // With last iteration (with and without input src_iter)
3169  rnn_backward(const primitive_desc &aprimitive_desc,
3170  const primitive::at &src_layer,
3171  const primitive::at &src_iter,
3172  const primitive::at &weights_layer,
3173  const primitive::at &weights_iter,
3174  const primitive::at &bias,
3175  const primitive::at &dst_layer,
3176  const primitive::at &dst_iter,
3177  const memory &diff_src_layer,
3178  const memory &diff_src_iter,
3179  const memory &diff_weights_layer,
3180  const memory &diff_weights_iter,
3181  const memory &diff_bias,
3182  const primitive::at &diff_dst_layer,
3183  const primitive::at &diff_dst_iter,
3184  const primitive::at &workspace) {
3185  mkldnn_primitive_t result;
3186  mkldnn_primitive_at_t inputs[10];
3187  const_mkldnn_primitive_t outputs[5];
3188  int idx=0;
3189  inputs[idx++] = src_layer.data;
3190  if (!is_null_memory(src_iter.data.primitive))
3191  inputs[idx++] = src_iter.data;
3192  inputs[idx++] = weights_layer.data;
3193  inputs[idx++] = weights_iter.data;
3194  if (!is_null_memory(bias.data.primitive))
3195  inputs[idx++] = bias.data;
3196  inputs[idx++] = dst_layer.data;
3197  if (!is_null_memory(dst_iter.data.primitive))
3198  inputs[idx++] = dst_iter.data;
3199  inputs[idx++] = diff_dst_layer.data;
3200  if (!is_null_memory(diff_dst_iter.data.primitive))
3201  inputs[idx++] = diff_dst_iter.data;
3202  inputs[idx++] = workspace.data;
3203 
3204  idx = 0;
3205  outputs[idx++] = diff_src_layer.get();
3206  if (!is_null_memory(diff_src_iter.get()))
3207  outputs[idx++] = diff_src_iter.get();
3208  outputs[idx++] = diff_weights_layer.get();
3209  outputs[idx++] = diff_weights_iter.get();
3210  if (!is_null_memory(diff_bias.get())) outputs[idx++] = diff_bias.get();
3212  aprimitive_desc.get(), inputs, outputs),
3213  "could not create an RNN backward primitive");
3214  reset(result);
3215  }
3216 };
3217 
3219 
3225 
// shuffle_forward: C++ wrapper around the forward shuffle primitive,
// parameterized by an axis and a group_size.
// NOTE(review): Doxygen listing. Lines 3228, 3231, 3238 and 3252 are absent.
3226 struct shuffle_forward : public primitive {
3227  struct desc {
3229  desc(prop_kind aprop_kind, const memory::desc &data_desc,
3230  int axis, int group_size) {
3232  mkldnn::convert_to_c(aprop_kind), &data_desc.data,
3233  axis, group_size),
3234  "could not create a shuffle forward descriptor");
3235  }
3236  };
3237 
// Only an engine-only constructor exists here; unlike most wrappers in this
// file there is no primitive_attr overload.
3239  primitive_desc(const desc &desc, const engine &e)
3240  : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
3241 
3242  REG_QUERY_MPD(src, src, 0);
3243  REG_QUERY_MPD(dst, dst, 0);
3244  };
3245 
// inputs = {src}, outputs = {dst}; arity (1, 1).
3246  shuffle_forward(const primitive_desc &aprimitive_desc,
3247  const primitive::at &src, const memory &dst) {
3248  mkldnn_primitive_t result;
3249  mkldnn_primitive_at_t inputs[] = { src.data };
3250  const_mkldnn_primitive_t outputs[] = { dst.get() };
3251  check_num_parameters(aprimitive_desc.get(), 1, 1, "shuffle forward");
3253  aprimitive_desc.get(), inputs, outputs),
3254  "could not create a shuffle forward primitive");
3255  reset(result);
3256  }
3257 };
3258 
// shuffle_backward: C++ wrapper around the backward shuffle primitive.
// It maps diff_dst to diff_src.
// NOTE(review): Doxygen listing. Lines 3261, 3263, 3269 and 3284 are absent.
3259 struct shuffle_backward : public primitive {
3260  struct desc {
3262  desc(const memory::desc &diff_data_desc, int axis, int group_size) {
3264  &diff_data_desc.data, axis, group_size),
3265  "could not create a shuffle backward descriptor");
3266  }
3267  };
3268 
// Requires the forward shuffle primitive_desc as a hint; no attr overload.
3270  primitive_desc(const desc &desc, const engine &e,
3271  const shuffle_forward::primitive_desc &hint_fwd_pd)
3272  : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
3273 
3274  REG_QUERY_MPD(diff_src, diff_src, 0);
3275  REG_QUERY_MPD(diff_dst, diff_dst, 0);
3276  };
3277 
// inputs = {diff_dst}, outputs = {diff_src}; arity (1, 1).
3278  shuffle_backward(const primitive_desc &aprimitive_desc,
3279  const primitive::at &diff_dst, const memory &diff_src) {
3280  mkldnn_primitive_t result;
3281  mkldnn_primitive_at_t inputs[] = { diff_dst.data};
3282  const_mkldnn_primitive_t outputs[] = { diff_src.get() };
3283  check_num_parameters(aprimitive_desc.get(), 1, 1, "shuffle backward");
3285  aprimitive_desc.get(), inputs, outputs),
3286  "could not create a shuffle backward primitive");
3287  reset(result);
3288  }
3289 };
3290 
3292 
3294 
3300 
3301 #ifndef DOXYGEN_SHOULD_SKIP_THIS
3302 template <> struct handle_traits<mkldnn_stream_t> {
3303  static constexpr auto destructor = &mkldnn_stream_destroy;
3304 };
3305 #endif
3306 
3307 struct stream: public handle<mkldnn_stream_t> {
3308  using handle::handle;
3309 
3313 
3315  return static_cast<mkldnn_stream_kind_t>(akind);
3316  }
3318  stream(kind akind) {
3319  mkldnn_stream_t astream;
3321  convert_to_c(akind)),
3322  "could not create a stream");
3323  reset(astream);
3324  }
3325 
3330  stream &submit(std::vector<primitive> primitives) {
3331  // TODO: find a proper way to convert vector<primitive> to
3332  // vector<mkldnn_primitive_t>
3333  if (primitives.size() == 0) return *this;
3334  std::vector<mkldnn_primitive_t> c_api_primitives;
3335  c_api_primitives.reserve(primitives.size());
3336  auto convert_to_c = [](primitive p) { return p.get(); };
3337  std::transform(primitives.begin(), primitives.end(),
3338  std::back_inserter(c_api_primitives), convert_to_c);
3339 
3340  mkldnn_primitive_t c_api_error_primitive;
3342  mkldnn_stream_submit(get(),
3343  c_api_primitives.size(), &c_api_primitives[0],
3344  &c_api_error_primitive),
3345  "could not submit primitives to a stream",
3346  &c_api_error_primitive);
3347 
3348  return *this;
3349  }
3350 
3357  bool wait(bool block = true) {
3358  mkldnn_primitive_t c_api_error_primitive;
3359  mkldnn_status_t status = mkldnn_stream_wait(get(),
3360  block, &c_api_error_primitive);
3361  if (status != mkldnn_success
3362  && status != mkldnn_try_again)
3363  error::wrap_c_api(status, "could not wait on a stream",
3364  &c_api_error_primitive);
3365  return (status == mkldnn_success);
3366  }
3367 
3369  mkldnn_primitive_t c_api_error_primitive;
3371  mkldnn_stream_rerun(get(), &c_api_error_primitive),
3372  "could not rerun a stream", &c_api_error_primitive);
3373  return *this;
3374  }
3375 };
3376 
3377 #undef REG_QUERY_MPD
3378 
3380 
3382 
3383 } // namespace mkldnn
3384 
3385 #endif
void append_sum(float scale=1.)
Definition: mkldnn.hpp:386
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2422
Definition: mkldnn.hpp:2373
LRN within a single channel.
Definition: mkldnn_types.h:567
primitive error_primitive
Definition: mkldnn.hpp:164
A descriptor of a Local Response Normalization (LRN) operation.
Definition: mkldnn_types.h:906
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1521
Definition: mkldnn.hpp:343
blocked weights format
Definition: mkldnn_types.h:348
inner_product_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at weights, const memory &dst)
Definition: mkldnn.hpp:2860
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2214
Definition: mkldnn.hpp:269
std::vector< const_mkldnn_primitive_desc_t > cpp_to_c(std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1103
blocked weights format
Definition: mkldnn_types.h:355
op descriptor
Definition: mkldnn_types.h:1248
primitive_desc(const memory::desc &output, int concat_dimension, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1113
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1664
mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_weights_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to weights using...
blocked weights format with additional buffer with size equal to the number of output channels multip...
Definition: mkldnn_types.h:393
Definition: mkldnn.hpp:3106
blocked weights format
Definition: mkldnn_types.h:332
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_destroy(mkldnn_primitive_attr_t attr)
Deletes an attr.
blocked weights format
Definition: mkldnn_types.h:433
mkldnn_status_t MKLDNN_API mkldnn_sum_primitive_desc_create(mkldnn_primitive_desc_t *sum_primitive_desc, const mkldnn_memory_desc_t *output_desc, int n, const float *scales, const_mkldnn_primitive_desc_t *input_pds)
Creates out-of-place sum_primitive_desc for sum of n inputs multiplied by scale with resulting output...
Definition: mkldnn.hpp:257
A Softmax primitive.
Definition: mkldnn_types.h:509
number of outputs expected
Definition: mkldnn_types.h:1237
bool operator!=(const handle &other) const
Definition: mkldnn.hpp:88
mkldnn_status_t MKLDNN_API mkldnn_stream_destroy(mkldnn_stream_t stream)
Destroys an execution stream.
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:3064
blocked weights format
Definition: mkldnn_types.h:438
convolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:1674
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:2534
stream & submit(std::vector< primitive > primitives)
Submits a vector of primitives to a stream for computations.
Definition: mkldnn.hpp:3330
bool operator==(const primitive_desc &other) const
Definition: mkldnn.hpp:821
A base class for all primitive descriptors.
Definition: mkldnn.hpp:1270
Definition: mkldnn.hpp:2247
mkldnn_status_t
Status values returned by Intel(R) MKL-DNN functions.
Definition: mkldnn_types.h:49
stream & rerun()
Definition: mkldnn.hpp:3368
Definition: mkldnn.hpp:2210
A descriptor of a convolution operation.
Definition: mkldnn_types.h:758
Definition: mkldnn.hpp:301
desc(prop_kind aprop_kind, const memory::desc &data_desc, int axis, int group_size)
Definition: mkldnn.hpp:3229
Definition: mkldnn.hpp:2185
The operation failed and should be retried.
Definition: mkldnn_types.h:55
memory null_memory(engine eng)
Definition: mkldnn.hpp:917
mkldnn_status_t MKLDNN_API mkldnn_memory_primitive_desc_create(mkldnn_primitive_desc_t *memory_primitive_desc, const mkldnn_memory_desc_t *memory_desc, mkldnn_engine_t engine)
Creates a memory_primitive_desc memory primitive descriptor using memory_desc and engine...
blocked weights format
Definition: mkldnn_types.h:292
mkldnn_status_t MKLDNN_API mkldnn_post_ops_create(mkldnn_post_ops_t *post_ops)
Creates an empty sequence of post operations post_ops.
Definition: mkldnn.hpp:330
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_destroy(mkldnn_primitive_desc_t primitive_desc)
Deletes a primitive_desc.
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1614
mkldnn_status_t MKLDNN_API mkldnn_concat_primitive_desc_create(mkldnn_primitive_desc_t *concat_primitive_desc, const mkldnn_memory_desc_t *output_desc, int n, int concat_dimension, const_mkldnn_primitive_desc_t *input_pds)
Creates out-of-place concat_primitive_desc for concatenation of n inputs by concat_dimension with res...
4D RNN bias tensor in the format (num_layers, num_directions, num_gates, output_channels).
Definition: mkldnn_types.h:266
4D data tensor with the physical layout chwn, used in Neon.
Definition: mkldnn_types.h:175
Definition: mkldnn.hpp:265
padding_kind
Definition: mkldnn.hpp:232
The operation failed because of incorrect function arguments.
Definition: mkldnn_types.h:57
Eltwise: exponent.
Definition: mkldnn_types.h:556
Forward data propagation (alias for mkldnn_forward_inference).
Definition: mkldnn_types.h:470
Definition: mkldnn.hpp:2048
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1576
Backward data propagation.
Definition: mkldnn_types.h:476
Definition: mkldnn.hpp:2446
static void validate_dims(std::vector< T > v)
Definition: mkldnn.hpp:587
Definition: mkldnn.hpp:3269
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_get_attr(const_mkldnn_primitive_desc_t primitive_desc, const_mkldnn_primitive_attr_t *attr)
Returns a constant reference to the attribute of a primitive_desc.
Definition: mkldnn.hpp:3259
mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init(mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims, mkldnn_data_type_t data_type, mkldnn_memory_format_t format)
Initializes a memory_desc memory descriptor using ndims, dims, data_type, and data format...
desc(prop_kind aprop_kind, const memory::desc &data_desc, int softmax_axis)
Definition: mkldnn.hpp:2412
Definition: mkldnn.hpp:275
blocked weights format
Definition: mkldnn_types.h:326
blocked weights format
Definition: mkldnn_types.h:404
Undefined memory format, used for empty memory descriptors.
Definition: mkldnn_types.h:149
const_mkldnn_primitive_desc_t get_primitive_desc() const
Returns the descriptor of the underlying C API primitive.
Definition: mkldnn.hpp:210
concat(const primitive_desc &concat_pd, std::vector< primitive::at > &inputs, const memory &output)
Definition: mkldnn.hpp:1154
memory::desc desc()
Returns the memory descriptor.
Definition: mkldnn.hpp:811
deconvolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:2009
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_weights_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to weights using...
float alpha
alpha is a negative slope parameter (used only if (flags & mkldnn_rnn_cell_with_relu) != 0) ...
Definition: mkldnn_types.h:1010
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_clone(mkldnn_primitive_attr_t *attr, const_mkldnn_primitive_attr_t existing_attr)
Makes a copy of an existing_attr.
#define TENSOR_MAX_DIMS
Maximum number of dimensions a tensor can have.
Definition: mkldnn_types.h:632
format
Memory format specification. See mkldnn_memory_format_t for a detailed description.
Definition: mkldnn.hpp:607
Definition: mkldnn.hpp:291
4D weights tensor with physical layout oihw, used in Caffe.
Definition: mkldnn_types.h:199
A descriptor of a Softmax operation.
Definition: mkldnn_types.h:856
blocked weights format
Definition: mkldnn_types.h:439
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_clone(mkldnn_primitive_desc_t *primitive_desc, const_mkldnn_primitive_desc_t existing_primitive_desc)
Makes a copy of a primitive_desc.
softmax_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2432
blocked weights format
Definition: mkldnn_types.h:440
blocked weights format
Definition: mkldnn_types.h:403
Definition: mkldnn.hpp:272
blocked data format
Definition: mkldnn_types.h:275
mkldnn_status_t MKLDNN_API mkldnn_memory_get_data_handle(const_mkldnn_primitive_t memory, void **handle)
For a memory primitive, returns the data handle.
Definition: mkldnn.hpp:244
mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_data_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for backward propagation with respect to data using al...
A descriptor of an inner product operation.
Definition: mkldnn_types.h:964
mkldnn_status_t MKLDNN_API mkldnn_post_ops_destroy(mkldnn_post_ops_t post_ops)
Deletes a post_ops sequence.
std::vector< std::remove_extent< mkldnn_dims_t >::type > dims
Definition: mkldnn.hpp:585
3D RNN data tensor in the format (seq_length, batch, input channels).
Definition: mkldnn_types.h:242
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:3239
An opaque structure for a chain of post operations.
An opaque structure to describe a primitive descriptor.
batch normalization descriptor
Definition: mkldnn_types.h:1257
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1733
mkldnn_rnn_direction_t
A direction of RNN primitive execution.
Definition: mkldnn_types.h:1017
void reset(T t, bool weak=false)
Resets the value of a C handle.
Definition: mkldnn.hpp:79
A convolution primitive.
Definition: mkldnn_types.h:503
primitive_desc(const desc &desc, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1881
mkldnn_lrn_desc_t data
Definition: mkldnn.hpp:2111
mkldnn_status_t MKLDNN_API mkldnn_memory_set_data_handle(mkldnn_primitive_t memory, void *handle)
For a memory primitive, sets the data handle.
engine(const mkldnn_engine_t &aengine)
Definition: mkldnn.hpp:539
blocked weights format with additional buffer with size equal to the number of output channels and co...
Definition: mkldnn_types.h:301
engine(const handle< mkldnn_primitive_desc_t > &pd)
Definition: mkldnn.hpp:542
engine get_engine()
Definition: mkldnn.hpp:1283
desc(dims adims, data_type adata_type, format aformat)
Constructs a memory descriptor.
Definition: mkldnn.hpp:777
blocked data format
Definition: mkldnn_types.h:276
mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_forward_desc_init(mkldnn_batch_normalization_desc_t *bnrm_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a batch normalization descriptor bnrm_desc for forward propagation using prop_kind (possi...
Definition: mkldnn.hpp:225
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2809
sum(const primitive_desc &sum_pd, std::vector< primitive::at > &inputs, const memory &output)
Definition: mkldnn.hpp:1243
An execution engine.
Definition: mkldnn.hpp:504
memory(const primitive_desc &adesc, void *ahandle)
Definition: mkldnn.hpp:867
blocked weights format
Definition: mkldnn_types.h:429
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2877
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_eltwise(mkldnn_post_ops_t post_ops, float scale, mkldnn_alg_kind_t alg, float alpha, float beta)
Appends eltwise post operation to the post_ops with given parameters kind, alpha, and beta (...
static void wrap_c_api(mkldnn_status_t status, const std::string &message, mkldnn_primitive_t *error_primitive=0)
A convenience function for wrapping calls to the C API. Checks the return status and throws an error ...
Definition: mkldnn.hpp:188
mkldnn_pooling_desc_t data
Definition: mkldnn.hpp:2249
blocked weights format
Definition: mkldnn_types.h:339
Undefined primitive (XXX: why do we have it?).
Definition: mkldnn_types.h:487
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_data_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor conv_desc for backward propagation with respect to data using ...
An inner product primitive.
Definition: mkldnn_types.h:517
Packed weights format used in RNN.
Definition: mkldnn_types.h:444
void check_num_parameters(const const_mkldnn_primitive_desc_t &aprimitive_desc, int n_inputs, int n_outputs, const std::string &prim_name)
Definition: mkldnn.hpp:922
Round down.
Definition: mkldnn_types.h:94
4D grouped weights tensor with the physical layout goiw.
Definition: mkldnn_types.h:223
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const softmax_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2461
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1751
Definition: mkldnn.hpp:264
round_mode get_int_output_round_mode() const
Definition: mkldnn.hpp:427
blocked weights format
Definition: mkldnn_types.h:435
blocked weights format
Definition: mkldnn_types.h:294
primitive_attr()
Definition: mkldnn.hpp:420
Definition: mkldnn_types.h:563
Definition: mkldnn.hpp:2358
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_weights_qparams(mkldnn_primitive_attr_t attr, int count, int mask, const float *weights_scales)
Sets quantization scales weights_scales for RNN weights tensors.
mkldnn_primitive_at_t MKLDNN_API mkldnn_primitive_at(const_mkldnn_primitive_t primitive, size_t output_index)
Creates an mkldnn_primitive_at_t structure from a primitive and output_index.
primitive_desc(const desc &desc, const engine &e, const softmax_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2457
mkldnn_softmax_desc_t data
Definition: mkldnn.hpp:2447
Definition: mkldnn.hpp:2421
void get_params_sum(int index, float &scale) const
Definition: mkldnn.hpp:391
Definition: mkldnn.hpp:247
32-bit signed integer.
Definition: mkldnn_types.h:78
primitive_desc(const desc &desc, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2890
Max pooling.
Definition: mkldnn_types.h:558
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1435
memory::desc zero_md()
Definition: mkldnn.hpp:911
Definition: mkldnn.hpp:337
primitive_desc(const memory::primitive_desc &input, memory::dims dims, memory::dims offsets)
Definition: mkldnn.hpp:1046
mkldnn_status_t MKLDNN_API mkldnn_softmax_forward_desc_init(mkldnn_softmax_desc_t *softmax_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, int softmax_axis)
Initializes a softmax_desc for forward propagation using prop_kind (possible values are mkldnn_forwar...
blocked weights format
Definition: mkldnn_types.h:314
blocked weights format
Definition: mkldnn_types.h:338
const post_ops get_post_ops() const
Definition: mkldnn.hpp:461
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims kernel, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:2187
execution engine
Definition: mkldnn_types.h:1233
stream(kind akind)
Constructs a stream.
Definition: mkldnn.hpp:3318
Definition: mkldnn.hpp:1045
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_next(mkldnn_primitive_desc_iterator_t iterator)
Iterates over primitive descriptors.
Definition: mkldnn.hpp:336
desc(const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2878
mkldnn_status_t MKLDNN_API mkldnn_pooling_backward_desc_init(mkldnn_pooling_desc_t *pool_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a pooling descriptor pool_desc for backward propagation using alg_kind, memory descriptors, and pooling parameters in the spatial domain: strides, kernel sizes, padding_l, padding_r, and padding_kind.
Definition: mkldnn.hpp:2184
blocked weights format
Definition: mkldnn_types.h:322
static mkldnn_memory_format_t convert_to_c(format aformat)
Definition: mkldnn.hpp:906
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const eltwise_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2378
Definition: mkldnn.hpp:321
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_create(mkldnn_primitive_attr_t *attr)
Creates an empty (default) attr attribute.
Definition: mkldnn_types.h:995
mkldnn_status_t MKLDNN_API mkldnn_stream_submit(mkldnn_stream_t stream, size_t n, mkldnn_primitive_t primitives[], mkldnn_primitive_t *error_primitive)
Submits primitives to an execution stream.
algorithm
Definition: mkldnn.hpp:255
input memory primitive desc
Definition: mkldnn_types.h:1263
blocked weights format
Definition: mkldnn_types.h:341
mkldnn_shuffle_desc_t data
Definition: mkldnn.hpp:3228
5D grouped weights tensor with the physical layout goihw, used in Caffe.
Definition: mkldnn_types.h:227
const_mkldnn_primitive_t primitive
Primitive to specify the output for.
Definition: mkldnn_types.h:1193
Definition: mkldnn.hpp:290
blocked weights format
Definition: mkldnn_types.h:354
rnn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src_layer, const primitive::at &src_iter, const primitive::at &weights_layer, const primitive::at &weights_iter, const primitive::at &bias, const memory &dst_layer, const memory &dst_iter, const memory &workspace)
Definition: mkldnn.hpp:3077
mkldnn_status_t MKLDNN_API mkldnn_rnn_cell_desc_init(mkldnn_rnn_cell_desc_t *rnn_cell_desc, mkldnn_alg_kind_t kind, mkldnn_alg_kind_t f, unsigned int flags, float alpha, float clipping)
Initializes a recurrent cell descriptor rnn_cell_desc using rnn_cell_desc, kind (possible values are ...
A descriptor of a element-wise operation.
Definition: mkldnn_types.h:820
rnn descriptor
Definition: mkldnn_types.h:1259
memory::primitive_desc variance_primitive_desc() const
Definition: mkldnn.hpp:2520
An element-wise primitive.
Definition: mkldnn_types.h:507
Definition: mkldnn.hpp:2445
blocked weights format
Definition: mkldnn_types.h:331
destination grad.
Definition: mkldnn_types.h:1270
algorithm get_cell_kind() const
Definition: mkldnn.hpp:3009
engine get_engine()
Definition: mkldnn.hpp:1240
Definition: mkldnn.hpp:2359
mkldnn_status_t MKLDNN_API mkldnn_stream_wait(mkldnn_stream_t stream, int block, mkldnn_primitive_t *error_primitive)
Waits for all primitives in the execution stream to finish.
mkldnn_alg_kind_t activation_kind
Activation function used.
Definition: mkldnn_types.h:1005
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1227
blocked weights format
Definition: mkldnn_types.h:344
A descriptor for an RNN operation.
Definition: mkldnn_types.h:1032
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1412
Definition: mkldnn.hpp:1101
Definition: mkldnn.hpp:278
Definition: mkldnn.hpp:259
eltwise descriptor
Definition: mkldnn_types.h:1253
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &mean, const memory &variance, const memory &workspace)
Definition: mkldnn.hpp:2628
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:1460
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_data_qparams(mkldnn_primitive_attr_t attr, const float scale, const float shift)
Sets quantization scale and shift for RNN data tensors.
Definition: mkldnn.hpp:277
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights_or_workspace, const memory &diff_src)
Definition: mkldnn.hpp:2765
lrn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2096
size_t MKLDNN_API mkldnn_engine_get_count(mkldnn_engine_kind_t kind)
Returns the number of engines of a particular kind.
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2921
batch_normalization_flag
Definition: mkldnn.hpp:289
A memory primitive.
Definition: mkldnn_types.h:489
float clipping
clipping parameter (used only if (flags & mkldnn_rnn_cell_with_clipping) != 0)
Definition: mkldnn_types.h:1013
blocked weights format
Definition: mkldnn_types.h:311
blocked weights format
Definition: mkldnn_types.h:325
desc(prop_kind aprop_kind, rnn_cell::desc cell, const rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc)
Definition: mkldnn.hpp:3109
Eltwise: soft_relu.
Definition: mkldnn_types.h:552
void set_post_ops(post_ops ops)
Definition: mkldnn.hpp:470
inner_product_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:2845
Definition: mkldnn.hpp:342
Definition: mkldnn.hpp:261
mkldnn_primitive_kind_t MKLDNN_API mkldnn_post_ops_get_kind(const_mkldnn_post_ops_t post_ops, int index)
Returns the type of post operation with index index in given post_ops.
RNN cell.
Definition: mkldnn_types.h:569
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2211
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1772
bool is_null_memory(const const_mkldnn_primitive_t &aprimitive)
Definition: mkldnn.hpp:942
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2894
Definition: mkldnn.hpp:368
blocked weights format
Definition: mkldnn_types.h:362
bool operator==(const handle &other) const
Definition: mkldnn.hpp:87
Definition: mkldnn.hpp:1372
Backward weights propagation.
Definition: mkldnn_types.h:478
void set_int_output_round_mode(round_mode mode)
Definition: mkldnn.hpp:434
mkldnn_rnn_desc_t data
Definition: mkldnn.hpp:3037
blocked weights format
Definition: mkldnn_types.h:432
32-bit/single-precision floating point.
Definition: mkldnn_types.h:76
blocked weights format
Definition: mkldnn_types.h:288
blocked data format
Definition: mkldnn_types.h:273
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1596
algorithm get_activation() const
Definition: mkldnn.hpp:3011
pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2222
2D weights tensor with physical layout oi.
Definition: mkldnn_types.h:184
Just a sentinel, not a real memory format.
Definition: mkldnn_types.h:448
Memory descriptor.
Definition: mkldnn_types.h:717
Definition: mkldnn.hpp:2808
Definition: mkldnn.hpp:304
blocked weights format
Definition: mkldnn_types.h:350
mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_data_desc_init(mkldnn_inner_product_desc_t *ip_desc, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc)
Initializes an inner product descriptor ip_desc for backward propagation with respect to data using m...
Base class for all computational primitives.
Definition: mkldnn.hpp:106
shuffle_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:3246
mkldnn_batch_normalization_flag_t
Flags for the batch-normalization primitive.
Definition: mkldnn_types.h:586
void set_clipping(float clipping)
Definition: mkldnn.hpp:3021
convolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:1688
mkldnn_lrn_desc_t data
Definition: mkldnn.hpp:2049
Definition: mkldnn.hpp:2807
desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon, unsigned flags)
Definition: mkldnn.hpp:2496
Definition: mkldnn.hpp:281
pooling descriptor
Definition: mkldnn_types.h:1255
Definition: mkldnn.hpp:2248
const mkldnn_memory_desc_t MKLDNN_API * mkldnn_primitive_desc_query_memory_d(const_mkldnn_primitive_desc_t primitive_desc)
Queries primitive descriptor for memory descriptor.
prop_kind
Definition: mkldnn.hpp:240
mkldnn_pooling_desc_t data
Definition: mkldnn.hpp:2186
Definition: mkldnn.hpp:267
blocked weights format
Definition: mkldnn_types.h:287
blocked data format
Definition: mkldnn_types.h:277
3D weights tensor with physical layout wio.
Definition: mkldnn_types.h:196
blocked weights format
Definition: mkldnn_types.h:414
blocked weights format
Definition: mkldnn_types.h:361
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_forward_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor conv_desc for forward propagation using prop_kind (p...
unsigned int flags
RNN cell flags.
Definition: mkldnn_types.h:1007
3D data tensor with the physical layout ncw.
Definition: mkldnn_types.h:163
blocked weights format
Definition: mkldnn_types.h:329
convolution_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src)
Definition: mkldnn.hpp:1558
The operation was successful.
Definition: mkldnn_types.h:51
blocked weights format with additional buffer with size equal to the number of groups and containing ...
Definition: mkldnn_types.h:424
mkldnn_status_t MKLDNN_API mkldnn_engine_create(mkldnn_engine_t *engine, mkldnn_engine_kind_t kind, size_t index)
Creates an engine of particular kind and index.
blocked weights format
Definition: mkldnn_types.h:386
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2947
primitive_desc(const desc &desc, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1660
desc(algorithm kind, algorithm activation_f)
Definition: mkldnn.hpp:2999
blocked weights format
Definition: mkldnn_types.h:400
Definition: mkldnn.hpp:327
Definition: mkldnn.hpp:245
primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr, const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd)
Definition: mkldnn.hpp:1271
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_int_output_round_mode(const_mkldnn_primitive_attr_t attr, mkldnn_round_mode_t *round_mode)
Returns integer output rounding mode round_mode for a given attr, previously set by mkldnn_primitive_...
blocked weights format
Definition: mkldnn_types.h:430
mkldnn_rnn_desc_t data
Definition: mkldnn.hpp:3108
Backward propagation (with respect to all parameters).
Definition: mkldnn_types.h:474
5D data tensor with the physical layout ndhwc, used in TensorFlow.
Definition: mkldnn_types.h:181
inner_product_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at diff_dst, const memory &diff_weights, const memory &diff_bias)
Definition: mkldnn.hpp:2971
softmax descriptor
Definition: mkldnn_types.h:1254
mkldnn_round_mode_t
Rounding mode.
Definition: mkldnn_types.h:90
A deconvolution primitive.
Definition: mkldnn_types.h:505
Definition: mkldnn.hpp:331
Definition: mkldnn.hpp:276
primitive_desc(const desc &adesc, const engine &aengine)
Constructs a memory primitive descriptor.
Definition: mkldnn.hpp:801
Use global statistics.
Definition: mkldnn_types.h:599
Definition: mkldnn.hpp:31
primitive_desc(int concat_dimension, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1126
blocked weights format
Definition: mkldnn_types.h:330
no query
Definition: mkldnn_types.h:1231
Definition: mkldnn.hpp:1712
blocked weights format
Definition: mkldnn_types.h:416
blocked weights format
Definition: mkldnn_types.h:346
blocked weights format
Definition: mkldnn_types.h:365
mkldnn_status_t MKLDNN_API mkldnn_convolution_forward_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a convolution descriptor conv_desc for forward propagation using prop_kind (possible valu...
mkldnn_status_t MKLDNN_API mkldnn_view_primitive_desc_create(mkldnn_primitive_desc_t *view_primitive_desc, const_mkldnn_primitive_desc_t memory_primitive_desc, const mkldnn_dims_t dims, const mkldnn_dims_t offsets)
Creates a view_primitive_desc for a given memory_primitive_desc, with dims sizes and offsets offsets...
8-bit unsigned integer.
Definition: mkldnn_types.h:84
blocked weights format
Definition: mkldnn_types.h:428
Definition: mkldnn.hpp:347
Average pooling include padding.
Definition: mkldnn_types.h:560
Unspecified format.
Definition: mkldnn_types.h:152
inner_product_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at weights, const memory &diff_src)
Definition: mkldnn.hpp:2903
Definition: mkldnn.hpp:2070
destination memory primitive desc
Definition: mkldnn_types.h:1269
memory::primitive_desc mean_primitive_desc() const
Definition: mkldnn.hpp:2518
5D RNN weights tensor in the format (num_layers, num_directions, input_channels, num_gates, output_channels).
Definition: mkldnn_types.h:252
GRU cell with linear before reset.
Definition: mkldnn_types.h:582
memory(const primitive_desc &adesc)
Constructs a memory primitive.
Definition: mkldnn.hpp:840
lrn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const primitive::at &workspace, const memory &diff_src)
Definition: mkldnn.hpp:2148
mkldnn_status_t MKLDNN_API mkldnn_shuffle_forward_desc_init(mkldnn_shuffle_desc_t *shuffle_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc, int axis, int group_size)
Initializes a shuffle_desc for forward propagation using prop_kind, memory descriptor data_desc...
Local response normalization (LRN) across multiple channels.
Definition: mkldnn_types.h:565
blocked weights format
Definition: mkldnn_types.h:310
GRU cell.
Definition: mkldnn_types.h:573
Eager stream.
Definition: mkldnn_types.h:1284
primitive_desc(const memory::primitive_desc &input, const memory::primitive_desc &output, const primitive_attr &aattr)
Definition: mkldnn.hpp:996
void set_output_scales(int mask, const std::vector< float > &scales)
Definition: mkldnn.hpp:454
at(const primitive &aprimitive, size_t at=0)
Constructs a wrapper specifying aprimitive output with index at.
Definition: mkldnn.hpp:143
implementation name
Definition: mkldnn_types.h:1244
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1950
Definition: mkldnn.hpp:1373
desc(const memory::desc &diff_data_desc, int axis, int group_size)
Definition: mkldnn.hpp:3262
Definition: mkldnn.hpp:3260
Definition: mkldnn.hpp:256
pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2286
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_output_scales(const_mkldnn_primitive_attr_t attr, int *count, int *mask, const float **scales)
Returns count, correspondence scale mask, and a pointer to a constant floating point array of output ...
3D weights tensor with physical layout oiw.
Definition: mkldnn_types.h:190
Eltwise: parametric exponential linear unit (ELU).
Definition: mkldnn_types.h:540
kind
Kinds of engines.
Definition: mkldnn.hpp:509
Definition: mkldnn.hpp:2110
Definition: mkldnn.hpp:2875
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2425
Intel(R) MKL-DNN exception class.
Definition: mkldnn.hpp:161
round_mode
Definition: mkldnn.hpp:223
bool operator==(mkldnn_data_type_t a, memory::data_type b)
Definition: mkldnn.hpp:951
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:1839
Eltwise: ReLU.
Definition: mkldnn_types.h:536
Definition: mkldnn.hpp:2409
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1374
Definition: mkldnn.hpp:233
1D data tensor.
Definition: mkldnn_types.h:158
mkldnn_primitive_at_t data
The underlying C API structure.
Definition: mkldnn.hpp:136
memory::primitive_desc query_mpd(query what, int idx=0) const
Queries and returns requested memory primitive descriptor.
Definition: mkldnn.hpp:1324
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const batch_normalization_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2707
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_post_ops(mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t post_ops)
Sets configured post_ops to an attribute attr for future use (when primitive descriptor is being crea...
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const rnn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:3146
primitive_desc(const desc &desc, const engine &e, const shuffle_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:3270
4D weights tensor with physical layout ihwo.
Definition: mkldnn_types.h:208
mkldnn_eltwise_desc_t data
Definition: mkldnn.hpp:2360
mkldnn_memory_format_t
Memory format specification.
Definition: mkldnn_types.h:147
Definition: mkldnn.hpp:1044
Eltwise: square.
Definition: mkldnn_types.h:542
blocked weights format
Definition: mkldnn_types.h:323
Definition: mkldnn.hpp:1178
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1394
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1056
Definition: mkldnn.hpp:282
mkldnn_status_t MKLDNN_API mkldnn_eltwise_forward_desc_init(mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, float alpha, float beta)
Initializes an eltwise_desc for forward propagation using prop_kind (possible values are mkldnn_forwa...
int MKLDNN_API mkldnn_memory_primitive_desc_equal(const_mkldnn_primitive_desc_t lhs, const_mkldnn_primitive_desc_t rhs)
Compares two descriptors of memory primitives.
void set_rnn_data_qparams(const float scale, const float shift)
Definition: mkldnn.hpp:475
static mkldnn_data_type_t convert_to_c(data_type adata_type)
Definition: mkldnn.hpp:903
4D data tensor with the physical layout nhwc, used in TensorFlow.
Definition: mkldnn_types.h:172
void set_data_handle(void *handle) const
Definition: mkldnn.hpp:897
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &mean, const memory &variance)
Definition: mkldnn.hpp:2603
Definition: mkldnn.hpp:268
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, int local_size, float alpha, float beta, float k)
Definition: mkldnn.hpp:2112
Backward bias propagation.
Definition: mkldnn_types.h:480
Definition: mkldnn.hpp:985
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, int local_size, float alpha, float beta)
Definition: mkldnn.hpp:2059
blocked weights format
Definition: mkldnn_types.h:425
Use scale and shift parameters.
Definition: mkldnn_types.h:612
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1714
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_forward_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor deconv_desc for forward propagation using prop_kind (possible ...
query
Definition: mkldnn.hpp:312
Definition: mkldnn.hpp:280
weights format with additional buffer size equal to the number of output channels multiplied by numbe...
Definition: mkldnn_types.h:384
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_query(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index, void *result)
Queries primitive descriptor.
float get_alpha() const
Definition: mkldnn.hpp:3014
blocked weights format
Definition: mkldnn_types.h:309
blocked weights format
Definition: mkldnn_types.h:402
A descriptor of a shuffle operation.
Definition: mkldnn_types.h:803
void get_params_eltwise(int index, float &scale, algorithm &alg, float &alpha, float &beta) const
Definition: mkldnn.hpp:403
Definition: mkldnn_types.h:1027
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_weights_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor conv_desc for backward propagation with respect to wei...
mkldnn_eltwise_desc_t data
Definition: mkldnn.hpp:2322
primitive_desc(const desc &desc, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1995
Definition: mkldnn.hpp:419
blocked weights format
Definition: mkldnn_types.h:419
blocked weights format
Definition: mkldnn_types.h:357
int get_gates_count() const
Definition: mkldnn.hpp:3026
int ndims
Number of dimensions.
Definition: mkldnn_types.h:722
reorder(const primitive_desc &aprimitive_desc, const primitive::at &input, const memory &output)
Definition: mkldnn.hpp:1009
Definition: mkldnn.hpp:2047
Definition: mkldnn.hpp:1102
kind
A proxy to C primitive kind enum.
Definition: mkldnn.hpp:113
blocked weights format with additional buffer with size equal to the number of groups and containing ...
Definition: mkldnn_types.h:377
5D grouped weights tensor with the physical layout giohw.
Definition: mkldnn_types.h:234
void set_alpha(float alpha)
Definition: mkldnn.hpp:3015
mkldnn_status_t MKLDNN_API mkldnn_eltwise_backward_desc_init(mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, float alpha, float beta)
Initializes an eltwise_desc for backward propagation using alg_kind algorithm memory descriptors diff...
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, int local_size, float alpha, float beta)
Definition: mkldnn.hpp:2122
5D data tensor with the physical layout ncdhw.
Definition: mkldnn_types.h:178
Definition: mkldnn.hpp:3227
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_destroy(mkldnn_primitive_desc_iterator_t iterator)
Deletes a primitive descriptor iterator.
5D RNN states tensor in the format (num_layers, num_directions, num_states, batch, state channels).
Definition: mkldnn_types.h:245
Definition: mkldnn.hpp:2134
size_t get_size() const
Returns the number of bytes required to allocate the memory described including the padding area...
Definition: mkldnn.hpp:817
mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_sum(mkldnn_post_ops_t post_ops, float scale)
Appends accumulation (sum) post operation to the post_ops.
Definition: mkldnn.hpp:1573
deconvolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:1807
An RNN primitive.
Definition: mkldnn_types.h:519
mkldnn_status_t MKLDNN_API mkldnn_primitive_get_output(const_mkldnn_primitive_t primitive, size_t index, const_mkldnn_primitive_t *output)
For a primitive, returns output at the index position.
blocked weights format
Definition: mkldnn_types.h:340
blocked weights format
Definition: mkldnn_types.h:283
mkldnn_status_t MKLDNN_API mkldnn_shuffle_backward_desc_init(mkldnn_shuffle_desc_t *shuffle_desc, const mkldnn_memory_desc_t *diff_data_desc, int axis, int group_size)
Initializes a shuffle_desc for backward propagation using memory descriptor diff_data_desc, axis, and group_size.
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:1911
Definition: mkldnn.hpp:2996
eltwise_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2387
mkldnn_prop_kind_t
Kinds of propagation.
Definition: mkldnn_types.h:458
A wrapper structure to specify a particular output of a primitive.
Definition: mkldnn.hpp:134
CPU engine.
Definition: mkldnn_types.h:1083
Definition: mkldnn.hpp:292
desc(algorithm alg_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, T alpha=0, T beta=0)
Definition: mkldnn.hpp:2363
Eltwise: square root.
Definition: mkldnn_types.h:546
blocked weights format
Definition: mkldnn_types.h:434
blocked weights format
Definition: mkldnn_types.h:290
mkldnn_stream_kind_t
Kinds of streams.
Definition: mkldnn_types.h:1280
Definition: mkldnn.hpp:271
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_int_output_round_mode(mkldnn_primitive_attr_t attr, mkldnn_round_mode_t round_mode)
Sets output rounding mode round_mode for integer operations for a given attr.
4D weights tensor with physical layout hwio, used in TensorFlow.
Definition: mkldnn_types.h:202
A wrapper structure to specify a particular output of a primitive.
Definition: mkldnn_types.h:1191
Winograd convolution.
Definition: mkldnn_types.h:528
Definition: mkldnn.hpp:246
Definition: mkldnn.hpp:344
Eltwise: linear.
Definition: mkldnn_types.h:548
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1840
16-bit bfloat (bfloat16).
Definition: mkldnn_types.h:86
mkldnn_status_t MKLDNN_API mkldnn_softmax_backward_desc_init(mkldnn_softmax_desc_t *softmax_desc, const mkldnn_memory_desc_t *diff_desc, const mkldnn_memory_desc_t *data_desc, int softmax_axis)
Initializes a softmax_desc for backward propagation using memory descriptors diff_desc and data_desc...
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1912
reorder(const primitive::at &input, const memory &output)
Definition: mkldnn.hpp:1020
Eltwise: logistic.
Definition: mkldnn_types.h:554
Definition: mkldnn.hpp:2687
Direct convolution.
Definition: mkldnn_types.h:526
Primitive iterator passed over last primitive descriptor.
Definition: mkldnn_types.h:64
Definition: mkldnn.hpp:339
Definition: mkldnn.hpp:270
lrn_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &workspace, const memory &dst)
Definition: mkldnn.hpp:2082
source gradient memory primitive desc
Definition: mkldnn_types.h:1266
mkldnn_alg_kind_t cell_kind
RNN cell kind.
Definition: mkldnn_types.h:1002
Definition: mkldnn.hpp:1501
mkldnn_batch_normalization_desc_t data
Definition: mkldnn.hpp:2689
Definition: mkldnn_types.h:1019
An opaque structure for primitive descriptor attributes.
Definition: mkldnn.hpp:313
blocked data format
Definition: mkldnn_types.h:279
mkldnn_status_t MKLDNN_API mkldnn_pooling_forward_desc_init(mkldnn_pooling_desc_t *pool_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a pooling descriptor pool_desc for forward propagation using prop_kind (possible values a...
blocked weights format
Definition: mkldnn_types.h:345
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, int local_size, float alpha, float beta, float k)
Definition: mkldnn.hpp:2050
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:2659
mkldnn_rnn_cell_desc_t c_rnn_cell_
Definition: mkldnn.hpp:2997
bool operator!=(const primitive_desc &other) const
Definition: mkldnn.hpp:826
runtime estimation (seconds)
Definition: mkldnn_types.h:1239
blocked weights format
Definition: mkldnn_types.h:418
bool operator==(const T other) const
Definition: mkldnn.hpp:61
A (in-place) concat primitive.
Definition: mkldnn_types.h:499
mkldnn_status_t MKLDNN_API mkldnn_stream_create(mkldnn_stream_t *stream, mkldnn_stream_kind_t stream_kind)
Creates an execution stream of stream_kind.
primitive_desc get_primitive_desc() const
Returns the descriptor of the memory primitive.
Definition: mkldnn.hpp:877
blocked weights format
Definition: mkldnn_types.h:312
LSTM cell.
Definition: mkldnn_types.h:571
blocked weights format
Definition: mkldnn_types.h:293
mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_backward_desc_init(mkldnn_batch_normalization_desc_t *bnrm_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a batch normalization descriptor bnrm_desc for backward propagation with respect to data ...
Definition: mkldnn_types.h:1028
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2507
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2833
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2836
Undefined data type, used for empty memory descriptors.
Definition: mkldnn_types.h:74
blocked weights format with additional buffer with size equal to the number of output channels multip...
Definition: mkldnn_types.h:372
Definition: mkldnn.hpp:1837
16-bit signed integer.
Definition: mkldnn_types.h:80
Definition: mkldnn.hpp:2321
A shuffle primitive.
Definition: mkldnn_types.h:495
blocked weights format with additional buffer with size equal to the number of output channels and co...
Definition: mkldnn_types.h:319
mkldnn_shuffle_desc_t data
Definition: mkldnn.hpp:3261
primitive_desc()
Definition: mkldnn.hpp:798
int len() const
Definition: mkldnn.hpp:376
mkldnn_status_t MKLDNN_API mkldnn_primitive_get_primitive_desc(const_mkldnn_primitive_t primitive, const_mkldnn_primitive_desc_t *primitive_desc)
Retrieves a reference to the primitive_desc descriptor of given primitive.
blocked weights format
Definition: mkldnn_types.h:328
primitive_desc(const memory::desc &output, const std::vector< float > &scales, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1190
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc)
Definition: mkldnn.hpp:2821
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_eltwise(const_mkldnn_post_ops_t post_ops, int index, float *scale, mkldnn_alg_kind_t *alg, float *alpha, float *beta)
Gets the eltwise parameters of the post operation with index index in the sequence of post_ops...
blocked data format
Definition: mkldnn_types.h:271
Definition: mkldnn.hpp:242
blocked weights format
Definition: mkldnn_types.h:347
mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_sum(const_mkldnn_post_ops_t post_ops, int index, float *scale)
Gets the parameters of the accumulation (sum) post operation with index index in the sequence of post...
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1502
blocked weights format
Definition: mkldnn_types.h:337
A (out-of-place) concat primitive.
Definition: mkldnn_types.h:497
blocked weights format
Definition: mkldnn_types.h:358
Fuse with ReLU.
Definition: mkldnn_types.h:621
Definition: mkldnn.hpp:260
Definition: mkldnn.hpp:279
static size_t get_count(kind akind)
Returns the number of engines of a certain kind.
Definition: mkldnn.hpp:520
mkldnn_query_t
Primitive descriptor query specification.
Definition: mkldnn_types.h:1230
A descriptor of a Batch Normalization operation.
Definition: mkldnn_types.h:933
static engine query(const primitive_desc &pd)
Definition: mkldnn.hpp:552
Definition: mkldnn.hpp:3035
blocked weights format
Definition: mkldnn_types.h:373
deconvolution_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:2023
blocked data format
Definition: mkldnn_types.h:278
blocked weights format
Definition: mkldnn_types.h:289
A sum primitive.
Definition: mkldnn_types.h:501
blocked weights format
Definition: mkldnn_types.h:360
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2782
Definition: mkldnn.hpp:303
blocked weights format
Definition: mkldnn_types.h:413
eltwise_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2345
blocked weights format
Definition: mkldnn_types.h:296
unsigned flags
Definition: mkldnn_types.h:960
mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create_v2(mkldnn_primitive_desc_t *reorder_primitive_desc, const_mkldnn_primitive_desc_t input, const_mkldnn_primitive_desc_t output, const_mkldnn_primitive_attr_t attr)
Initializes a reorder_primitive_desc using an attr attribute and descriptors of input and output memo...
blocked weights format
Definition: mkldnn_types.h:295
blocked weights format
Definition: mkldnn_types.h:363
Definition: mkldnn.hpp:2995
Convolution algorithm (either direct or Winograd) is chosen just in time.
Definition: mkldnn_types.h:530
softmax_backward(const primitive_desc &aprimitive_desc, const primitive::at &dst, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2471
blocked weights format
Definition: mkldnn_types.h:284
Definition: mkldnn.hpp:3036
Definition: mkldnn.hpp:258
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2335
mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_data_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated deconvolution descriptor conv_desc for backward propagation with respect to dat...
blocked weights format
Definition: mkldnn_types.h:420
mkldnn_status_t MKLDNN_API mkldnn_stream_rerun(mkldnn_stream_t stream, mkldnn_primitive_t *error_primitive)
Reruns all the primitives within the stream.
2D weights tensor with physical layout io.
Definition: mkldnn_types.h:187
memory consumption – extra (scratch) memory, additional to all inputs and outputs memory (bytes) ...
Definition: mkldnn_types.h:1240
blocked weights format
Definition: mkldnn_types.h:353
A batch normalization primitive.
Definition: mkldnn_types.h:515
A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base class for primitive (mkldnn_p...
Definition: mkldnn.hpp:55
Definition: mkldnn_types.h:524
engine(kind akind, size_t index)
Constructs an engine.
Definition: mkldnn.hpp:530
Definition: mkldnn.hpp:2320
A descriptor of a pooling operation.
Definition: mkldnn_types.h:872
Definition: mkldnn.hpp:3307
Definition: mkldnn.hpp:273
Definition: mkldnn.hpp:274
engine get_engine()
Definition: mkldnn.hpp:830
error(mkldnn_status_t astatus, std::string amessage, mkldnn_primitive_t aerror_primitive=0)
Constructs an error instance.
Definition: mkldnn.hpp:173
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1999
const char * impl_info_str() const
Returns implementation name.
Definition: mkldnn.hpp:1299
deconvolution descriptor
Definition: mkldnn_types.h:1251
std::vector< const_mkldnn_primitive_desc_t > cpp_to_c(std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1180
blocked weights format
Definition: mkldnn_types.h:366
shuffle_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:3278
primitive_desc(const memory::primitive_desc &input, const memory::primitive_desc &output)
Definition: mkldnn.hpp:987
primitive_desc(const desc &desc, const engine &e, const pooling_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2273
mkldnn_memory_desc_t data
The underlying C API data structure.
Definition: mkldnn.hpp:770
mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_iterator_fetch(const_mkldnn_primitive_desc_iterator_t iterator)
Fetches the current primitive descriptor.
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:1463
engine get_engine()
Definition: mkldnn.hpp:1006
int MKLDNN_API mkldnn_primitive_desc_query_s32(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index)
Queries primitive descriptor for a signed 32-bit int.
8-bit signed integer.
Definition: mkldnn_types.h:82
mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create(mkldnn_primitive_desc_t *reorder_primitive_desc, const_mkldnn_primitive_desc_t input, const_mkldnn_primitive_desc_t output)
Initializes a reorder_primitive_desc using descriptors of input and output memory primitives...
The data in padding regions is zero.
Definition: mkldnn_types.h:454
int MKLDNN_API mkldnn_rnn_cell_get_states_count(const mkldnn_rnn_cell_desc_t *rnn_cell_desc)
Returns the number of states of a particular rnn_cell_desc.
Definition: mkldnn.hpp:2334
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc)
Definition: mkldnn.hpp:2931
source memory primitive desc
Definition: mkldnn_types.h:1265
mkldnn_primitive_kind_t
Kinds of primitives.
Definition: mkldnn_types.h:485
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const deconvolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1885
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1972
Definition: mkldnn.hpp:3238
Winograd deconvolution.
Definition: mkldnn_types.h:534
Definition: mkldnn.hpp:248
number of inputs expected
Definition: mkldnn_types.h:1236
mkldnn_softmax_desc_t data
Definition: mkldnn.hpp:2411
Definition: mkldnn.hpp:346
Definition: mkldnn.hpp:3060
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2510
desc(prop_kind aprop_kind, algorithm alg_kind, const memory::desc &src_desc, T alpha=0, T beta=0)
Definition: mkldnn.hpp:2324
An unspecified stream.
Definition: mkldnn_types.h:1282
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:1795
void * get_data_handle() const
Returns a handle of the data contained in the memory primitive. On the CPU engine, this is a pointer to the allocated memory.
Definition: mkldnn.hpp:890
A view primitive.
Definition: mkldnn_types.h:491
size_t MKLDNN_API mkldnn_memory_primitive_desc_get_size(const_mkldnn_primitive_desc_t memory_primitive_desc)
Returns the size (in bytes) that is required for given memory_primitive_desc.
Definition: mkldnn.hpp:3107
Definition: mkldnn.hpp:262
Definition: mkldnn.hpp:329
Definition: mkldnn.hpp:3141
blocked weights format
Definition: mkldnn_types.h:327
mkldnn_primitive_kind_t convert_to_c(primitive::kind akind)
Definition: mkldnn.hpp:154
blocked data format
Definition: mkldnn_types.h:274
Definition: mkldnn.hpp:341
Definition: mkldnn.hpp:332
Definition: mkldnn.hpp:324
Definition: mkldnn.hpp:334
Average pooling exclude padding.
Definition: mkldnn_types.h:562
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_post_ops(const_mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t *post_ops)
Returns post_ops for given attr.
mkldnn_status_t MKLDNN_API mkldnn_primitive_create(mkldnn_primitive_t *primitive, const_mkldnn_primitive_desc_t primitive_desc, const mkldnn_primitive_at_t *inputs, const_mkldnn_primitive_t *outputs)
Creates a primitive using a primitive_desc descriptor and arrays of inputs and outputs.
primitive::kind kind(int index) const
Definition: mkldnn.hpp:378
Definition: mkldnn_types.h:998
Forward data propagation (inference mode).
Definition: mkldnn_types.h:468
primitive_attr get_primitive_attr() const
Definition: mkldnn.hpp:1285
6D grouped weights tensor with the physical layout goidhw, used in Caffe.
Definition: mkldnn_types.h:238
5D weights tensor with physical layout iodhw, used in Caffe.
Definition: mkldnn_types.h:214
A class that provides the destructor for an Intel(R) MKL-DNN C handle.
Definition: mkldnn.hpp:40
data_type
Data type specification. See mkldnn_data_type_t for a detailed description.
Definition: mkldnn.hpp:595
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const memory &dst)
Definition: mkldnn.hpp:2550
Direct deconvolution.
Definition: mkldnn_types.h:532
Eltwise: abs.
Definition: mkldnn_types.h:544
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst, const memory &mean, const memory &variance)
Definition: mkldnn.hpp:2572
blocked weights format
Definition: mkldnn_types.h:388
pooling_backward(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &workspace, const memory &diff_src)
Definition: mkldnn.hpp:2298
blocked weights format
Definition: mkldnn_types.h:313
A memory descriptor.
Definition: mkldnn.hpp:767
deconvolution_backward_data(const primitive_desc &aprimitive_desc, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src)
Definition: mkldnn.hpp:1894
5D grouped weights tensor with the physical layout hwigo, used in TensorFlow.
Definition: mkldnn_types.h:231
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2338
blocked weights format
Definition: mkldnn_types.h:410
bool operator!=(mkldnn_data_type_t a, memory::data_type b)
Definition: mkldnn.hpp:954
void set_rnn_weights_qparams(int mask, const std::vector< float > &scales)
Definition: mkldnn.hpp:481
handle(T t=0, bool weak=false)
Constructs a C handle wrapper.
Definition: mkldnn.hpp:67
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_forward_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated convolution descriptor conv_desc for forward propagation using prop_kind (possi...
Eltwise: hyperbolic tangent non-linearity (tanh)
Definition: mkldnn_types.h:538
mkldnn_inner_product_desc_t data
Definition: mkldnn.hpp:2920
mkldnn_status_t status
Definition: mkldnn.hpp:162
deconvolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:1822
blocked weights format with additional buffer with size equal to the number of output channels and co...
Definition: mkldnn_types.h:409
T get() const
Returns the value of the underlying C handle.
Definition: mkldnn.hpp:85
blocked weights format
Definition: mkldnn_types.h:401
mkldnn_status_t MKLDNN_API mkldnn_engine_destroy(mkldnn_engine_t engine)
Destroys an engine.
view(const primitive_desc &view_pd, primitive::at input)
Definition: mkldnn.hpp:1072
blocked weights format
Definition: mkldnn_types.h:367
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1932
blocked weights format
Definition: mkldnn_types.h:364
2D data tensor.
Definition: mkldnn_types.h:160
primitive_desc(const desc &desc, const engine &e, const batch_normalization_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2703
blocked weights format
Definition: mkldnn_types.h:321
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc)
Definition: mkldnn.hpp:2810
mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_data_desc_init(mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a dilated convolution descriptor conv_desc for backward propagation with respect to data ...
bool wait(bool block=true)
Waits for all computations submitted to the stream to complete.
Definition: mkldnn.hpp:3357
mkldnn_status_t MKLDNN_API mkldnn_lrn_backward_desc_init(mkldnn_lrn_desc_t *lrn_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *diff_data_desc, const mkldnn_memory_desc_t *data_desc, int local_size, float alpha, float beta, float k)
Initializes an lrn_desc for backward propagation using alg_kind, memory descriptors data_desc and dif...
Primitive or engine failed on execution.
Definition: mkldnn_types.h:66
memory descriptor for memory and view
Definition: mkldnn_types.h:1249
view(memory input, memory::dims dims, memory::dims offsets)
Definition: mkldnn.hpp:1081
Definition: mkldnn.hpp:266
An LRN primitive.
Definition: mkldnn_types.h:513
Definition: mkldnn_types.h:1024
mkldnn_padding_kind_t
Kinds of padding.
Definition: mkldnn_types.h:452
rnn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src_layer, const primitive::at &src_iter, const primitive::at &weights_layer, const primitive::at &weights_iter, const primitive::at &bias, const primitive::at &dst_layer, const primitive::at &dst_iter, const memory &diff_src_layer, const memory &diff_src_iter, const memory &diff_weights_layer, const memory &diff_weights_iter, const memory &diff_bias, const primitive::at &diff_dst_layer, const primitive::at &diff_dst_iter, const primitive::at &workspace)
Definition: mkldnn.hpp:3169
Lazy stream.
Definition: mkldnn_types.h:1286
Definition: mkldnn.hpp:333
desc(const memory::desc &diff_desc, const memory::desc &data_desc, int softmax_axis)
Definition: mkldnn.hpp:2448
blocked weights format
Definition: mkldnn_types.h:415
Definition: mkldnn.hpp:305
void get_output_scales(int &mask, std::vector< float > &scales) const
Definition: mkldnn.hpp:440
blocked weights format
Definition: mkldnn_types.h:286
desc(algorithm kind)
Definition: mkldnn.hpp:3005
primitive_desc(const desc &desc, const engine &e, const rnn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:3142
5D RNN weights tensor in the format (num_layers, num_directions, num_gates, output_channels, input_channels).
Definition: mkldnn_types.h:259
blocked weights format
Definition: mkldnn_types.h:356
const_mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_query_pd(const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what, int index)
Queries primitive descriptor for primitive descriptor.
Definition: mkldnn.hpp:2918
shuffle descriptor
Definition: mkldnn_types.h:1252
Forward data propagation (training mode).
Definition: mkldnn_types.h:464
Definition: mkldnn.hpp:345
primitive_desc(const desc &desc, const engine &e, const lrn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2135
inner_product_backward_weights(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at diff_dst, const memory &diff_weights)
Definition: mkldnn.hpp:2957
mkldnn_convolution_desc_t data
Definition: mkldnn.hpp:1575
memory(const primitive &aprimitive)
Constructs a memory primitive from a generic primitive.
Definition: mkldnn.hpp:836
3D data tensor with the physical layout nwc.
Definition: mkldnn_types.h:166
engine get_engine()
Definition: mkldnn.hpp:1151
post_ops()
Definition: mkldnn.hpp:369
An opaque structure to describe a primitive.
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights, const primitive::at &workspace, const memory &diff_src, const memory &diff_weights)
Definition: mkldnn.hpp:2743
A tensor in a generic format described by the stride and blocking values in each dimension.
Definition: mkldnn_types.h:156
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1375
mkldnn_data_type_t
Data type specification.
Definition: mkldnn_types.h:72
Definition: mkldnn.hpp:1500
Definition: mkldnn.hpp:326
Definition: mkldnn.hpp:319
convolution descriptor
Definition: mkldnn_types.h:1250
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1549
A memory primitive descriptor.
Definition: mkldnn.hpp:794
Definition: mkldnn.hpp:315
Definition: mkldnn.hpp:2456
mkldnn_status_t MKLDNN_API mkldnn_lrn_forward_desc_init(mkldnn_lrn_desc_t *lrn_desc, mkldnn_prop_kind_t prop_kind, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc, int local_size, float alpha, float beta, float k)
Initializes an lrn_desc for forward propagation using prop_kind (possible values are mkldnn_forward_t...
blocked weights format
Definition: mkldnn_types.h:342
primitive_desc(const desc &desc, const engine &e, const convolution_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:1545
blocked weights format
Definition: mkldnn_types.h:333
handle & operator=(const handle &other)
Definition: mkldnn.hpp:72
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst)
Definition: mkldnn.hpp:2673
Eltwise: bounded_relu.
Definition: mkldnn_types.h:550
Definition: mkldnn.hpp:2410
#define REG_QUERY_MPD(name, what, idx)
Definition: mkldnn.hpp:1349
Definition: mkldnn_types.h:1021
convolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst)
Definition: mkldnn.hpp:1485
mkldnn_engine_kind_t
Kinds of engines.
Definition: mkldnn_types.h:1079
Definition: mkldnn_types.h:994
int MKLDNN_API mkldnn_rnn_cell_get_gates_count(const mkldnn_rnn_cell_desc_t *rnn_cell_desc)
Returns the number of gates of a particular rnn_cell_desc.
Queried element is not required for given primitive.
Definition: mkldnn_types.h:68
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:3061
blocked weights format
Definition: mkldnn_types.h:437
bool operator!=(const T other) const
Definition: mkldnn.hpp:62
blocked weights format
Definition: mkldnn_types.h:385
Memory primitive that describes the data.
Definition: mkldnn.hpp:580
Weights format used in 8bit Winograd convolution.
Definition: mkldnn_types.h:442
Definition: mkldnn.hpp:328
primitive_desc(const desc &desc, const engine &e)
Definition: mkldnn.hpp:2071
Definition: mkldnn.hpp:2109
Definition: mkldnn.hpp:302
Round nearest.
Definition: mkldnn_types.h:92
blocked weights format
Definition: mkldnn_types.h:436
Definition: mkldnn.hpp:243
batch_normalization_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &mean, const primitive::at &variance, const primitive::at &diff_dst, const primitive::at &weights, const memory &diff_src, const memory &diff_weights)
Definition: mkldnn.hpp:2724
Definition: mkldnn.hpp:1711
const void * const_mkldnn_op_desc_t
A pointer to any of the operation descriptors (constant variant).
Definition: mkldnn_types.h:711
static mkldnn_stream_kind_t convert_to_c(kind akind)
Definition: mkldnn.hpp:3314
blocked weights format
Definition: mkldnn_types.h:285
blocked weights format
Definition: mkldnn_types.h:431
Definition: mkldnn.hpp:1909
memory::primitive_desc dst_primitive_desc() const
Definition: mkldnn.hpp:1139
mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_create_v2(mkldnn_primitive_desc_iterator_t *iterator, const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr, mkldnn_engine_t engine, const_mkldnn_primitive_desc_t hint_forward_primitive_desc)
Creates a primitive descriptor iterator for given op_desc, attr, engine, and optionally a hint primit...
Definition: mkldnn.hpp:2492
pooling_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const memory &dst, const memory &workspace)
Definition: mkldnn.hpp:2234
convolution_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const primitive::at &bias, const memory &dst)
Definition: mkldnn.hpp:1472
4D weights tensor with physical layout iohw.
Definition: mkldnn_types.h:211
A reorder primitive.
Definition: mkldnn_types.h:493
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:1798
rnn_direction
Definition: mkldnn.hpp:300
primitive_desc(const std::vector< float > &scales, std::vector< memory::primitive_desc > inputs)
Definition: mkldnn.hpp:1209
blocked weights format
Definition: mkldnn_types.h:411
blocked weights format with additional buffer with size equal to the number of output channels multip...
Definition: mkldnn_types.h:399
blocked weights format
Definition: mkldnn_types.h:336
An unspecified engine.
Definition: mkldnn_types.h:1081
desc(const mkldnn_memory_desc_t &adata)
Constructs a memory descriptor from a C API data structure.
Definition: mkldnn.hpp:790
blocked weights format
Definition: mkldnn_types.h:359
Definition: mkldnn.hpp:1179
int MKLDNN_API mkldnn_post_ops_len(const_mkldnn_post_ops_t post_ops)
Returns the length of post operations for given post_ops.
engine get_engine()
Definition: mkldnn.hpp:1069
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const pooling_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2277
blocked weights format
Definition: mkldnn_types.h:412
blocked weights format
Definition: mkldnn_types.h:387
mkldnn_alg_kind_t
Kinds of algorithms.
Definition: mkldnn_types.h:523
primitive_desc(const desc &desc, const engine &e, const inner_product_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2943
Definition: mkldnn.hpp:263
inner product descriptor
Definition: mkldnn_types.h:1258
blocked weights format
Definition: mkldnn_types.h:394
A pooling primitive.
Definition: mkldnn_types.h:511
weights memory primitive descriptor
Definition: mkldnn_types.h:1267
output memory primitive desc
Definition: mkldnn_types.h:1264
Definition: mkldnn.hpp:2272
blocked weights format
Definition: mkldnn_types.h:417
blocked weights format
Definition: mkldnn_types.h:349
5D weights tensor with physical layout dhwio, used in TensorFlow.
Definition: mkldnn_types.h:217
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
Definition: mkldnn.hpp:2074
mkldnn_batch_normalization_desc_t data
Definition: mkldnn.hpp:2494
Definition: mkldnn.hpp:986
mkldnn_status_t MKLDNN_API mkldnn_primitive_destroy(mkldnn_primitive_t primitive)
Deletes a primitive.
Definition: mkldnn.hpp:335
std::string message
Definition: mkldnn.hpp:163
Definition: mkldnn.hpp:3226
mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_weights_desc_init(mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides, const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind)
Initializes a deconvolution descriptor conv_desc for backward propagation with respect to weights usi...
mkldnn_status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_rnn_cell_desc_t *rnn_cell_desc, const mkldnn_rnn_direction_t direction, const mkldnn_memory_desc_t *src_layer_desc, const mkldnn_memory_desc_t *src_iter_desc, const mkldnn_memory_desc_t *weights_layer_desc, const mkldnn_memory_desc_t *weights_iter_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_layer_desc, const mkldnn_memory_desc_t *dst_iter_desc, const mkldnn_memory_desc_t *diff_src_layer_desc, const mkldnn_memory_desc_t *diff_src_iter_desc, const mkldnn_memory_desc_t *diff_weights_layer_desc, const mkldnn_memory_desc_t *diff_weights_iter_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_layer, const mkldnn_memory_desc_t *diff_dst_iter_desc)
Initializes an RNN descriptor rnn_desc for backward propagation using prop_kind, rnn_cell_desc, direction, and memory descriptors.
primitive_desc(const desc &desc, const engine &e, const eltwise_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2374
Definition: mkldnn.hpp:316
blocked weights format
Definition: mkldnn_types.h:324
handle(const handle &other)
Definition: mkldnn.hpp:71
Forward data propagation (alias for mkldnn_forward_training)
Definition: mkldnn_types.h:472
3D RNN data tensor in the format (batch, seq_length, input channels).
Definition: mkldnn_types.h:240
mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_output_scales(mkldnn_primitive_attr_t attr, int count, int mask, const float *scales)
Sets output scales for primitive operations.
Definition: mkldnn.hpp:241
lrn descriptor
Definition: mkldnn_types.h:1256
workspace memory primitive desc
Definition: mkldnn_types.h:1271
lrn_backward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &diff_dst, const memory &diff_src)
Definition: mkldnn.hpp:2162
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1636
bool next_impl()
Advances to the next implementation for the given op descriptor.
Definition: mkldnn.hpp:1313
mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_weights_desc_init(mkldnn_inner_product_desc_t *ip_desc, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *diff_weights_desc, const mkldnn_memory_desc_t *diff_bias_desc, const mkldnn_memory_desc_t *diff_dst_desc)
Initializes an inner product descriptor ip_desc for backward propagation with respect to weights usin...
blocked weights format
Definition: mkldnn_types.h:282
blocked weights format
Definition: mkldnn_types.h:291
mkldnn_deconvolution_desc_t data
Definition: mkldnn.hpp:1713
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, T epsilon, unsigned flags)
Definition: mkldnn.hpp:2691
blocked weights format
Definition: mkldnn_types.h:343
Definition: mkldnn.hpp:224
weights format with additional buffer size equal to the number of output channels and containing the ...
Definition: mkldnn_types.h:308
primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e, const lrn_forward::primitive_desc &hint_fwd_pd)
Definition: mkldnn.hpp:2139
float get_clipping() const
Definition: mkldnn.hpp:3020
weights grad.
Definition: mkldnn_types.h:1268
4D data tensor with the physical layout nchw, used in Caffe.
Definition: mkldnn_types.h:169
Definition: mkldnn.hpp:322
mkldnn_status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_rnn_cell_desc_t *rnn_cell_desc, const mkldnn_rnn_direction_t direction, const mkldnn_memory_desc_t *src_layer_desc, const mkldnn_memory_desc_t *src_iter_desc, const mkldnn_memory_desc_t *weights_layer_desc, const mkldnn_memory_desc_t *weights_iter_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_layer_desc, const mkldnn_memory_desc_t *dst_iter_desc)
Initializes an RNN descriptor rnn_desc for forward propagation using prop_kind, rnn_cell_desc, direction, and memory descriptors.
void append_eltwise(float scale, algorithm alg, float alpha, float beta)
Definition: mkldnn.hpp:396
primitive kind
Definition: mkldnn_types.h:1234
blocked data format
Definition: mkldnn_types.h:272
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims dilates, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1858
int get_state_count() const
Definition: mkldnn.hpp:3029
blocked weights format
Definition: mkldnn_types.h:320
Definition: mkldnn.hpp:318
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:2250
batch_normalization_forward(const primitive_desc &aprimitive_desc, const primitive::at &src, const primitive::at &weights, const memory &dst, const memory &mean, const memory &variance, const memory &workspace)
Definition: mkldnn.hpp:2587
kind
Definition: mkldnn.hpp:3310
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims strides, const memory::dims padding_l, const memory::dims padding_r, const padding_kind apadding_kind)
Definition: mkldnn.hpp:1503
Definition: mkldnn.hpp:340
desc(prop_kind aprop_kind, rnn_cell::desc cell, const rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc)
Definition: mkldnn.hpp:3038
mkldnn_status_t MKLDNN_API mkldnn_inner_product_forward_desc_init(mkldnn_inner_product_desc_t *ip_desc, mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *src_desc, const mkldnn_memory_desc_t *weights_desc, const mkldnn_memory_desc_t *bias_desc, const mkldnn_memory_desc_t *dst_desc)
Initializes an inner product descriptor ip_desc for forward propagation using prop_kind (possible val...