-
Notifications
You must be signed in to change notification settings - Fork 264
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
28 changed files
with
3,023 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
#include <ATen/ATen.h> | ||
|
||
#include <ATen/NativeFunctions.h> | ||
#include <ATen/Parallel.h> | ||
#include <ATen/native/ReduceOpsUtils.h> | ||
#include <ATen/native/cpu/utils.h> | ||
#include <ATen/record_function.h> | ||
#include <c10/util/irange.h> | ||
|
||
#include "SelectiveScan.h" | ||
#include "utils/library.h" | ||
|
||
namespace torch_ipex { | ||
namespace cpu { | ||
|
||
IPEX_DEFINE_DISPATCH(selective_scan_kernel_stub); | ||
IPEX_DEFINE_DISPATCH(selective_state_update_kernel_stub); | ||
|
||
/** | ||
* Does selective scan algorithm in Mamba Paper. | ||
* Paper: https://arxiv.org/abs/2312.00752 | ||
* Official Python Implementation: | ||
* selective_scan_ref: | ||
* https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L113 | ||
* @param u: (batch, dim, len) or (batch, len, dim) | ||
* @param delta: same shape as u | ||
* @param A: (dim, dstate) or (dstate, dim) | ||
* @param B: (batch, dstate, len) or (batch, dstate, 2len) or (battch, ngroups, | ||
* dstate, len) | ||
* @param C: (batch, dstate, len) or (batch, dstate, 2len) or (battch, ngroups, | ||
* dstate, len) | ||
* @param D: (dim,) or None | ||
* @param z: (batch, dim, len) or None | ||
* @param delta_bias: (dim,) or None | ||
* @param delta_softplus: bool | ||
* @param return_last_state: bool | ||
* @return: out: (batch, dim, len), last_state: (batch, dim, dstate) | ||
*/ | ||
std::tuple<at::Tensor, at::Tensor> selective_scan( | ||
const at::Tensor& u, | ||
const at::Tensor& delta, | ||
const at::Tensor& A, | ||
const at::Tensor& B, | ||
const at::Tensor& C, | ||
const c10::optional<at::Tensor>& D, | ||
const c10::optional<at::Tensor>& z, | ||
const c10::optional<at::Tensor>& delta_bias, | ||
bool delta_softplus, | ||
bool return_last_state) { | ||
RECORD_FUNCTION("selective_scan_fn", c10::ArrayRef<c10::IValue>({})); | ||
return selective_scan_kernel_stub( | ||
kCPU, | ||
u, | ||
delta, | ||
A, | ||
B, | ||
C, | ||
D, | ||
z, | ||
delta_bias, | ||
delta_softplus, | ||
return_last_state); | ||
} | ||
|
||
/** | ||
* Official Python Implementation: | ||
* selective_state_update_ref: | ||
* https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py#L219 | ||
* @param state: (batch, dim, dstate) or (batch, nheads, dim, dstate) | ||
* @param x: (batch, dim) or (batch, nheads, dim) | ||
* @param dt: (batch, dim) or (batch, nheads, dim) | ||
* @param A: (dim, dstate) or (nheads, dim, dstate) or (dstate, dim) or (nheads, | ||
* dstate, dim) | ||
* @param B: (batch, dstate) or (batch, ngroups, dstate) | ||
* @param C: (batch, dstate) or (batch, ngroups, dstate) | ||
* @param D: (dim,) or (nheads, dim) or None | ||
* @param z: (batch, dim) or (batch, nheads, dim) or None | ||
* @param dt_bias: (dim,) or (nheads, dim) or None | ||
* @param dt_softplus: bool | ||
* @return: out: (batch, dim) or (batch, nheads, dim) | ||
*/ | ||
at::Tensor selective_state_update( | ||
const at::Tensor& state, | ||
const at::Tensor& x, | ||
const at::Tensor& dt, | ||
const at::Tensor& A, | ||
const at::Tensor& B, | ||
const at::Tensor& C, | ||
const c10::optional<at::Tensor>& D, | ||
const c10::optional<at::Tensor>& z, | ||
const c10::optional<at::Tensor>& dt_bias, | ||
bool dt_softplus) { | ||
RECORD_FUNCTION("selective_state_update", c10::ArrayRef<c10::IValue>({})); | ||
return selective_state_update_kernel_stub( | ||
kCPU, state, x, dt, A, B, C, D, z, dt_bias, dt_softplus); | ||
} | ||
|
||
} // namespace cpu | ||
} // namespace torch_ipex | ||
|
||
namespace { | ||
|
||
IPEX_TORCH_LIBRARY_FRAGMENT(torch_ipex, m) { | ||
m.def( | ||
"selective_scan_fn(Tensor u, Tensor delta, Tensor A, Tensor B, Tensor C, Tensor? D, Tensor? z, Tensor? delta_bias, bool delta_softplus, bool return_last_state) -> (Tensor, Tensor)"); | ||
m.impl( | ||
"selective_scan_fn", | ||
c10::DispatchKey::CPU, | ||
torch_ipex::cpu::selective_scan); | ||
m.def( | ||
"selective_state_update(Tensor state, Tensor x, Tensor dt, Tensor A, Tensor B, Tensor C, Tensor? D, Tensor? z, Tensor? dt_bias, bool dt_softplus) -> (Tensor)"); | ||
m.impl( | ||
"selective_state_update", | ||
c10::DispatchKey::CPU, | ||
torch_ipex::cpu::selective_state_update); | ||
} | ||
|
||
} // namespace |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
#pragma once | ||
|
||
#include <ATen/Tensor.h> | ||
#include <dyndisp/DispatchStub.h> | ||
|
||
namespace torch_ipex { | ||
namespace cpu { | ||
|
||
std::tuple<at::Tensor, at::Tensor> selective_scan( | ||
const at::Tensor& u, | ||
const at::Tensor& delta, | ||
const at::Tensor& A, | ||
const at::Tensor& B, | ||
const at::Tensor& C, | ||
const c10::optional<at::Tensor>& D, | ||
const c10::optional<at::Tensor>& z, | ||
const c10::optional<at::Tensor>& delta_bias, | ||
bool delta_softplus, | ||
bool return_last_state); | ||
at::Tensor selective_state_update( | ||
const at::Tensor& state, | ||
const at::Tensor& x, | ||
const at::Tensor& dt, | ||
const at::Tensor& A, | ||
const at::Tensor& B, | ||
const at::Tensor& C, | ||
const c10::optional<at::Tensor>& D, | ||
const c10::optional<at::Tensor>& z, | ||
const c10::optional<at::Tensor>& dt_bias, | ||
bool dt_softplus); | ||
|
||
using selective_scan_kernel_fn = std::tuple<at::Tensor, at::Tensor> (*)( | ||
const at::Tensor& u, | ||
const at::Tensor& delta, | ||
const at::Tensor& A, | ||
const at::Tensor& B, | ||
const at::Tensor& C, | ||
const c10::optional<at::Tensor>& D, | ||
const c10::optional<at::Tensor>& z, | ||
const c10::optional<at::Tensor>& delta_bias, | ||
bool delta_softplus, | ||
bool return_last_state); | ||
using selective_state_update_fn = at::Tensor (*)( | ||
const at::Tensor& state, | ||
const at::Tensor& x, | ||
const at::Tensor& dt, | ||
const at::Tensor& A, | ||
const at::Tensor& B, | ||
const at::Tensor& C, | ||
const c10::optional<at::Tensor>& D, | ||
const c10::optional<at::Tensor>& z, | ||
const c10::optional<at::Tensor>& dt_bias, | ||
bool dt_softplus); | ||
IPEX_DECLARE_DISPATCH(selective_scan_kernel_fn, selective_scan_kernel_stub); | ||
IPEX_DECLARE_DISPATCH( | ||
selective_state_update_fn, | ||
selective_state_update_kernel_stub); | ||
|
||
} // namespace cpu | ||
} // namespace torch_ipex |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
#include <aten/Conv.h> | ||
#include "mkl.h" | ||
#include "vec/vec.h" | ||
|
||
namespace torch_ipex { | ||
namespace cpu { | ||
namespace { | ||
template <typename T> | ||
std::tuple<at::Tensor, at::Tensor> causal_conv1d_update_kernel_inner( | ||
const at::Tensor& hidden_states, | ||
const at::Tensor& conv_states, | ||
const at::Tensor& conv_weights, | ||
const c10::optional<at::Tensor>& conv_bias, | ||
bool silu_activation) { | ||
auto bs = conv_states.size(0); | ||
auto channels = conv_states.size(1); | ||
auto kernel_size = conv_states.size(2); | ||
auto has_bias = conv_bias.has_value(); | ||
auto bias_ptr = has_bias ? conv_bias.value().data_ptr<T>() : nullptr; | ||
auto conv_states_ptr = conv_states.data_ptr<T>(); | ||
auto conv_weights_ptr = conv_weights.data_ptr<T>(); | ||
auto hidden_states_ptr = hidden_states.data_ptr<T>(); | ||
auto hidden_states_strideB = hidden_states.stride(0); | ||
auto hidden_states_strideC = hidden_states.stride(1); | ||
auto conv_states_strideB = conv_states.stride(0); | ||
auto conv_states_strideC = conv_states.stride(1); | ||
auto conv_states_strideK = conv_states.stride(2); | ||
auto conv_weights_strideC = conv_weights.stride(0); | ||
#pragma omp parallel for collapse(2) | ||
for (auto bi = 0; bi < bs; bi++) { | ||
for (auto ci = 0; ci < channels; ci++) { | ||
auto conv_weights_start = ci * conv_weights_strideC; | ||
float out = 0.0f; | ||
auto conv_states_start = | ||
bi * conv_states_strideB + ci * conv_states_strideC; | ||
for (auto k = 1; k < kernel_size; k++) { | ||
auto conv_states_idx = conv_states_start + k * conv_states_strideK; | ||
out += conv_weights_ptr[conv_weights_start + k - 1] * | ||
conv_states_ptr[conv_states_idx]; | ||
conv_states_ptr[conv_states_idx - conv_states_strideK] = | ||
conv_states_ptr[conv_states_idx]; | ||
} | ||
auto hidden_states_idx = | ||
bi * hidden_states_strideB + ci * hidden_states_strideC; | ||
out += hidden_states_ptr[hidden_states_idx] * | ||
conv_weights_ptr[conv_weights_start + kernel_size - 1]; | ||
conv_states_ptr | ||
[conv_states_start + (kernel_size - 1) * conv_states_strideK] = | ||
hidden_states_ptr[hidden_states_idx]; | ||
if (has_bias) { | ||
out += bias_ptr[ci]; | ||
} | ||
if (silu_activation) { | ||
out = out / (1 + expf(-out)); | ||
} | ||
hidden_states_ptr[hidden_states_idx] = out; | ||
} | ||
} | ||
return std::make_tuple(std::move(hidden_states), std::move(conv_states)); | ||
} | ||
|
||
std::tuple<at::Tensor, at::Tensor> causal_conv1d_update_kernel_impl( | ||
const at::Tensor& hidden_states, | ||
const at::Tensor& conv_states, | ||
const at::Tensor& conv_weights, | ||
const c10::optional<at::Tensor>& conv_bias, | ||
bool silu_activation) { | ||
if (hidden_states.scalar_type() == at::ScalarType::Float) { | ||
return causal_conv1d_update_kernel_inner<float>( | ||
hidden_states, conv_states, conv_weights, conv_bias, silu_activation); | ||
} else if (hidden_states.scalar_type() == at::ScalarType::BFloat16) { | ||
return causal_conv1d_update_kernel_inner<at::BFloat16>( | ||
hidden_states, conv_states, conv_weights, conv_bias, silu_activation); | ||
} else if (hidden_states.scalar_type() == at::ScalarType::Half) { | ||
return causal_conv1d_update_kernel_inner<at::Half>( | ||
hidden_states, conv_states, conv_weights, conv_bias, silu_activation); | ||
} else { | ||
TORCH_CHECK( | ||
false, | ||
"Only support bfloat16, float16 and float for causal_conv1d_update"); | ||
} | ||
} | ||
} // anonymous namespace | ||
IPEX_REGISTER_DISPATCH( | ||
causal_conv1d_update_kernel_stub, | ||
&causal_conv1d_update_kernel_impl); | ||
|
||
} // namespace cpu | ||
} // namespace torch_ipex |
Oops, something went wrong.