mirror of https://github.com/alibaba/MNN.git
				
				
				
			
		
			
				
	
	
		
			213 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			C++
		
	
	
	
			
		
		
	
	
			213 lines
		
	
	
		
			8.4 KiB
		
	
	
	
		
			C++
		
	
	
	
//
 | 
						|
//  BlstmComputer.hpp
 | 
						|
//  MNN
 | 
						|
//
 | 
						|
//  Created by MNN on 2020/04/30.
 | 
						|
//  Copyright © 2018, Alibaba Group Holding Limited
 | 
						|
//
 | 
						|
 | 
						|
#ifndef BLSTMCOMPUTER_hpp
 | 
						|
#define BLSTMCOMPUTER_hpp
 | 
						|
 | 
						|
#include <memory>
 | 
						|
#include <vector>
 | 
						|
 | 
						|
#include "MNN/ErrorCode.hpp"
 | 
						|
#include "MNN_generated.h"
 | 
						|
#include "backend/cpu/CPUBackend.hpp"
 | 
						|
#include "core/Concurrency.h"
 | 
						|
#include "core/Macro.h"
 | 
						|
#include "core/TensorUtils.hpp"
 | 
						|
 | 
						|
using std::shared_ptr;
 | 
						|
using std::vector;
 | 
						|
 | 
						|
namespace MNN {
 | 
						|
 | 
						|
class BlstmComputer {
 | 
						|
  /**
 | 
						|
      Blstm:
 | 
						|
      Xt = input at timestep t
 | 
						|
      Ct-1 = cell state of last time step
 | 
						|
      O =  sigmoid activation
 | 
						|
      x = matrix product
 | 
						|
      * =  matrix dot product
 | 
						|
      Input gate:   It = Og(Xt x Wi + Ht-1 x Ui + Bi)
 | 
						|
      Next gate:    Nt = tanh(Xt x Wn + Ht-1 x Un + Bn)
 | 
						|
      Forget gate:  Ft = Og(Xt x Wf + Ht-1 x Uf + Bf)
 | 
						|
      Output gate:  Ot = Og(Xt x Wo + Ht-1 x Uo + Bo)
 | 
						|
      Cell state:   Ct = Nt * It + Ct-1 * Ft
 | 
						|
      Hidden state: Ht = tanh(Ct) * Ot
 | 
						|
      output : Ht
 | 
						|
 | 
						|
      Suppose input is a (Batch, Timestep, Feature) tensor
 | 
						|
      General usage:
 | 
						|
          (1). Construct a BlstmComputer* blstm = new BlstmComputer();
 | 
						|
          (2). Call blstm.importWeights() to import weight into this blstm.
 | 
						|
          (3). Upon every execution, first blstm.onResize(), then
 | 
						|
     blstm.onExecute() This is a single layer blstm. If you want to construct a
 | 
						|
     multi-layer blstm, you can just construct multiple blstm instances with
 | 
						|
     proper args and connect them together.
 | 
						|
  */
 | 
						|
 | 
						|
public:
 | 
						|
  /**
 | 
						|
   * @brief construct the BlstmComputer instance.
 | 
						|
   * @param inDim input dimension, correspond to 'Feature' in input(Batch,
 | 
						|
   * Timestep, Feature)
 | 
						|
   * @param stateSize hidden state & cell state size.
 | 
						|
   * @param bidirectional if this is a bidirectional or unidirectional lstm
 | 
						|
   * @param backend backend
 | 
						|
   */
 | 
						|
  BlstmComputer(int inDim, int stateSize, bool bidirectional,
 | 
						|
                MNN::CPUBackend *backend);
 | 
						|
  virtual ~BlstmComputer();
 | 
						|
  /**
 | 
						|
   * @brief sigmoid activation function
 | 
						|
   */
 | 
						|
  static float sigmoid(float x);
 | 
						|
 | 
						|
  /**
 | 
						|
   * @brief trim tensor into correct storage order. For NCHW and NHWC, data will
 | 
						|
   * be directly copied, interal storage order will not be changed. For NC4HW4,
 | 
						|
   * onCopyBuffer() will be used, interal storage order will be changed.
 | 
						|
   */
 | 
						|
  void trimTensor(Tensor *src_tensor, Tensor *tgt_tensor);
 | 
						|
 | 
						|
  /**
 | 
						|
   * @brief allocate space for all the weights and bias. And import data from
 | 
						|
   weightsVec.
 | 
						|
   * @param weightsVec
 | 
						|
      WeightsVec must has the same order as mWeights. This method will copy each
 | 
						|
   tensor in WeightsVec to corresponding mWeight. for weightsVec[0-3, 12-15],
 | 
						|
   shape = (mInDim, mStateSize) for weightsVec[4-7, 16-19], shape = (mStateSize,
 | 
						|
   mStateSize) for weightsVec[8-11, 20-23], shape = (mStateSize) For
 | 
						|
   bidirectional blstm, WeightsVec's size must equals to 24. For unidirectional
 | 
						|
   lstm, WeightsVec's size must equals to 12. By default, tensor in weightsVec
 | 
						|
   should be a NCHW or NHWC format tensor. If a NC4HW4 is passed, this method
 | 
						|
   will transform it into NCHW format tensor, and the internel storage order
 | 
						|
   might be changed. Thus, user should handle the data storage correctly.
 | 
						|
   */
 | 
						|
  ErrorCode importWeights(const vector<shared_ptr<Tensor>> &weightsVec);
 | 
						|
  /**
 | 
						|
   * @brief Need to be called before every onExecute(). This method will resize
 | 
						|
   * the internal tensors' memory which are used by calculation process.
 | 
						|
   * @param timeSteps, input's timeSteps
 | 
						|
   * @param batchSize, input's batchSize
 | 
						|
   */
 | 
						|
  ErrorCode onResize(int timeSteps, int batchSize);
 | 
						|
  /**
 | 
						|
   * @param input input tensor, shape = (B, T, F). Should be a NCHW or
 | 
						|
   * NHWCtensor. If a NC4HW4 is passed, this method will transform it into NCHW
 | 
						|
   * format, thus the internal storage order will be changed. User
 | 
						|
   * should handle the data storage order correctly.
 | 
						|
   * @param batchLengths length for each data slot in this batch. If current
 | 
						|
   * timestep > length, this data slot's output will be set to 0.
 | 
						|
   * @param initH initial HiddenState of this blstm. Each element of initH
 | 
						|
   * should be a (Batch, mStateSize) tensor. If bidirectional, initH.size() must
 | 
						|
   * = 2. If unidirectional, initH.size() must = 1. If not provide
 | 
						|
   * initH, it will be initialized to all 0.
 | 
						|
   * @param initC initial CellState of this blstm. Each element of initC should
 | 
						|
   * be a (Batch, mStateSize) tensor. If bidirectional, initC.size() must = 2.
 | 
						|
   * If unidirectional, initC.size() must = 1. If not provide
 | 
						|
   * initC, it will be initialized to all 0.
 | 
						|
   */
 | 
						|
  ErrorCode onExecute(Tensor *input, const vector<int> &batchLengths = {},
 | 
						|
                      const vector<shared_ptr<Tensor>> &initH = {},
 | 
						|
                      const vector<shared_ptr<Tensor>> &initC = {});
 | 
						|
 | 
						|
  /**
 | 
						|
   * @brief get the output tensor of this blstm.
 | 
						|
   */
 | 
						|
  shared_ptr<Tensor> output();
 | 
						|
  /**
 | 
						|
   * @brief get backend instance stored in this blstm instance.
 | 
						|
   */
 | 
						|
  CPUBackend *backend();
 | 
						|
 | 
						|
private:
 | 
						|
  int mInDim;          //  dimension for input' Feature
 | 
						|
  int mStateSize;      // dimension for hidden state and cell_state.
 | 
						|
  bool mBidirectional; // uni or bidirectional of this blstm
 | 
						|
  int mBatchSize = 0;
 | 
						|
  int mTimeSteps = 0;
 | 
						|
  shared_ptr<Tensor> mInput;  // (B, T, F) tensor
 | 
						|
  shared_ptr<Tensor> mOutput; // (B, T, F) tensor
 | 
						|
  vector<shared_ptr<Tensor>> mGateInputs;
 | 
						|
  vector<shared_ptr<Tensor>> mGateOutputs;
 | 
						|
  // mHiddenStates[0] is hidden state forward. mHiddenStates[1] = hidden state
 | 
						|
  // backward if bidirectional
 | 
						|
  vector<shared_ptr<Tensor>> mHiddenStates;
 | 
						|
  // mCellStates[0] is cell state forward. mCellStates[1] = cell state backward
 | 
						|
  // if bidirectional
 | 
						|
  vector<shared_ptr<Tensor>> mCellStates;
 | 
						|
  /*
 | 
						|
  mWeights[0] : Wi forward, shape = (mInDim, mStateSize)
 | 
						|
  mWeights[1] : Wn forward, shape = (mInDim, mStateSize)
 | 
						|
  mWeights[2] : Wf forward, shape = (mInDim, mStateSize)
 | 
						|
  mWeights[3] : Wo forward, shape = (mInDim, mStateSize)
 | 
						|
  mWeights[4] : Ui forward, shape = (mStateSize, mStateSize)
 | 
						|
  mWeights[5] : Un forward, shape = (mStateSize, mStateSize)
 | 
						|
  mWeights[6] : Uf forward, shape = (mStateSize, mStateSize)
 | 
						|
  mWeights[7] : Uo forward, shape = (mStateSize, mStateSize)
 | 
						|
  mWeights[8] : Bi forward, shape = (mStateSize)
 | 
						|
  mWeights[9] : Bn forward, shape = (mStateSize)
 | 
						|
  mWeights[10] : Bf forward, shape = (mStateSize)
 | 
						|
  mWeights[11] : Bo forward, shape = (mStateSize)
 | 
						|
  mWeights[12] : Wi backward if bidirectional, shape = (mInDim, mStateSize)
 | 
						|
  mWeights[13] : Wn backward if bidirectional, shape = (mInDim, mStateSize)
 | 
						|
  mWeights[14] : Wf backward if bidirectional, shape = (mInDim, mStateSize)
 | 
						|
  mWeights[15] : Wo backward if bidirectional, shape = (mInDim, mStateSize)
 | 
						|
  mWeights[16] : Ui backward if bidirectional, shape = (mStateSize, mStateSize)
 | 
						|
  mWeights[17] : Un backward if bidirectional, shape = (mStateSize, mStateSize)
 | 
						|
  mWeights[18] : Uf backward if bidirectional, shape = (mStateSize, mStateSize)
 | 
						|
  mWeights[19] : Uo backward if bidirectional, shape = (mStateSize, mStateSize)
 | 
						|
  mWeights[20] : Bi backward if bidirectional, shape = (mStateSize)
 | 
						|
  mWeights[21] : Bn backward if bidirectional, shape = (mStateSize)
 | 
						|
  mWeights[22] : Bf backward if bidirectional, shape = (mStateSize)
 | 
						|
  mWeights[23] : Bo backward if bidirectional, shape = (mStateSize)
 | 
						|
  */
 | 
						|
  vector<shared_ptr<Tensor>> mWeights;
 | 
						|
  MNN::CPUBackend *mBackend;
 | 
						|
 | 
						|
  /*
 | 
						|
      To make it more clear for users about how to wrap weights and input of
 | 
						|
    this blstm , we provide a simple example below.
 | 
						|
 | 
						|
      Suppose we have a blstm and input, with
 | 
						|
      Batch = 2, Timestep = 2, F(inDim) = 3, stateSize = 2, bidirectional =
 | 
						|
    true.
 | 
						|
 | 
						|
      Say if you want a input like this:
 | 
						|
                           | - F - |
 | 
						|
      timestep1, batch1    1,  2,  3
 | 
						|
      timestep2, batch1    4,  5,  6
 | 
						|
      timestep1, batch2    7,  8,  9
 | 
						|
      timestep2, batch2    10, 11, 12
 | 
						|
 | 
						|
      If you pass a NCHW or NHWC tensor as input, the internal storage order
 | 
						|
    should be: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,11, 12
 | 
						|
 | 
						|
      Also, if you want Wi to be like this:
 | 
						|
      | - stateSize - |
 | 
						|
    --|    1,    2    |
 | 
						|
    F |    3,    4    |
 | 
						|
    --|    5,    6    |
 | 
						|
 | 
						|
      If you pass a NCHW or NHWC tensor as Wi, the internal storage order should
 | 
						|
    be: 1, 2, 3, 4, 5, 6
 | 
						|
 | 
						|
      Then input matrix can then multiply with Wi.
 | 
						|
 | 
						|
      So the general principle of warpping input, weight, initH/initC is:
 | 
						|
          1. if you use NCHW/NHWC as source, make sure interal data is stored -1
 | 
						|
    dim, then -2 dim ....
 | 
						|
          2. if you use NC4HW4 as source, make sure after using onCopyBuffer(),
 | 
						|
    the resulting tensor is stored -1 dim, then -2 dim .... interally.
 | 
						|
  */
 | 
						|
};
 | 
						|
 | 
						|
} // namespace MNN
 | 
						|
 | 
						|
#endif /* BLSTMCOMPUTER_hpp */
 |