Enable optimized Jamba (#3406)
blzheng authored Dec 19, 2024
1 parent 1be1dc8 commit 950e509
Showing 28 changed files with 3,023 additions and 16 deletions.
33 changes: 33 additions & 0 deletions csrc/cpu/aten/Conv.cpp
@@ -8,6 +8,7 @@
namespace torch_ipex {
namespace cpu {

IPEX_DEFINE_DISPATCH(causal_conv1d_update_kernel_stub);
std::vector<int64_t> calc_conv_output_size(
at::IntArrayRef input_size,
at::IntArrayRef kernel_size,
@@ -505,6 +506,32 @@ at::Tensor convolution_forward(
weight_channels_last);
}

/**
* Official Python implementation: causal_conv1d_update_ref:
* https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py#L206
* @param hidden_states (batch, dim) or (batch, dim, seqlen)
* @param conv_states (batch, dim, state_len), where state_len >= width - 1
* @param conv_weights (dim, width)
* @param conv_bias (dim,)
* @param silu_activation If true, apply the SiLU activation function.
* @return (hidden_states, conv_states)
*/
std::tuple<at::Tensor, at::Tensor> causal_conv1d_update(
const at::Tensor& hidden_states,
const at::Tensor& conv_states,
const at::Tensor& conv_weights,
const c10::optional<at::Tensor>& conv_bias,
bool silu_activation) {
RECORD_FUNCTION("causal_conv1d_update", c10::ArrayRef<c10::IValue>({}));
return causal_conv1d_update_kernel_stub(
kCPU,
hidden_states,
conv_states,
conv_weights,
conv_bias,
silu_activation);
}

} // namespace cpu
} // namespace torch_ipex

@@ -561,6 +588,12 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
"convolution_forward",
c10::DispatchKey::CPU,
torch_ipex::cpu::convolution_forward_impl);
m.def(
"causal_conv1d_update(Tensor hidden_states, Tensor conv_states, Tensor conv_weights, Tensor? conv_bias, bool silu_activation) -> (Tensor, Tensor)");
m.impl(
"causal_conv1d_update",
c10::DispatchKey::CPU,
torch_ipex::cpu::causal_conv1d_update);
// bw
m.def(
"convolution_backward(Tensor input, Tensor weight, Tensor? bias, Tensor grad_output, bool[3] out_mask, "
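Once the extension library is loaded, the op registered above is callable from Python as torch.ops.torch_ipex.causal_conv1d_update. The snippet below is a minimal usage sketch based on that schema and the shapes in the C++ doc comment; the tensor sizes and the use of intel_extension_for_pytorch as the loader are illustrative assumptions rather than part of this commit.

import torch
import intel_extension_for_pytorch  # noqa: F401  -- assumed to register the torch_ipex ops

batch, dim, width = 2, 64, 4                   # illustrative sizes
hidden_states = torch.randn(batch, dim)        # (batch, dim): one new token per channel
conv_states = torch.zeros(batch, dim, width)   # (batch, dim, state_len): rolling window
conv_weights = torch.randn(dim, width)         # (dim, width): depthwise filter
conv_bias = torch.randn(dim)                   # (dim,); the schema also accepts None

# causal_conv1d_update(Tensor hidden_states, Tensor conv_states, Tensor conv_weights,
#                      Tensor? conv_bias, bool silu_activation) -> (Tensor, Tensor)
out, new_states = torch.ops.torch_ipex.causal_conv1d_update(
    hidden_states, conv_states, conv_weights, conv_bias, True)

As the kernel in CausalConvKrnl.cpp (further down) shows, hidden_states and conv_states are updated in place and returned, so out aliases hidden_states.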
17 changes: 17 additions & 0 deletions csrc/cpu/aten/Conv.h
@@ -1,6 +1,7 @@
#pragma once

#include <ATen/Tensor.h>
#include <dyndisp/DispatchStub.h>
#include <torch/csrc/autograd/custom_function.h>

#include <ideep.hpp>
@@ -51,6 +52,13 @@ std::vector<int64_t> calc_conv_output_size(
at::IntArrayRef stride,
at::IntArrayRef dilation);

std::tuple<at::Tensor, at::Tensor> causal_conv1d_update(
const at::Tensor& hidden_states,
const at::Tensor& conv_states,
const at::Tensor& conv_weights,
const c10::optional<at::Tensor>& conv_bias,
bool silu_activation);

// IPEX customized convolution OP with n-D packed weight
class IPEXConvolutionOp : public torch::autograd::Function<IPEXConvolutionOp> {
public:
@@ -95,5 +103,14 @@ at::Tensor convolution_forward(
c10::optional<at::IntArrayRef> dilation,
c10::optional<bool> weight_channels_last);

using causal_conv1d_update_kernel_fn = std::tuple<at::Tensor, at::Tensor> (*)(
const at::Tensor& hidden_states,
const at::Tensor& conv_states,
const at::Tensor& conv_weights,
const c10::optional<at::Tensor>& conv_bias,
bool silu_activation);
IPEX_DECLARE_DISPATCH(
causal_conv1d_update_kernel_fn,
causal_conv1d_update_kernel_stub);
} // namespace cpu
} // namespace torch_ipex
118 changes: 118 additions & 0 deletions csrc/cpu/aten/SelectiveScan.cpp
@@ -0,0 +1,118 @@
#include <ATen/ATen.h>

#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/cpu/utils.h>
#include <ATen/record_function.h>
#include <c10/util/irange.h>

#include "SelectiveScan.h"
#include "utils/library.h"

namespace torch_ipex {
namespace cpu {

IPEX_DEFINE_DISPATCH(selective_scan_kernel_stub);
IPEX_DEFINE_DISPATCH(selective_state_update_kernel_stub);

/**
* Does selective scan algorithm in Mamba Paper.
* Paper: https://arxiv.org/abs/2312.00752
* Official Python Implementation:
* selective_scan_ref:
* https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L113
* @param u: (batch, dim, len) or (batch, len, dim)
* @param delta: same shape as u
* @param A: (dim, dstate) or (dstate, dim)
 * @param B: (batch, dstate, len) or (batch, dstate, 2len) or (batch, ngroups,
 * dstate, len)
 * @param C: (batch, dstate, len) or (batch, dstate, 2len) or (batch, ngroups,
 * dstate, len)
* @param D: (dim,) or None
* @param z: (batch, dim, len) or None
* @param delta_bias: (dim,) or None
* @param delta_softplus: bool
* @param return_last_state: bool
* @return: out: (batch, dim, len), last_state: (batch, dim, dstate)
*/
std::tuple<at::Tensor, at::Tensor> selective_scan(
const at::Tensor& u,
const at::Tensor& delta,
const at::Tensor& A,
const at::Tensor& B,
const at::Tensor& C,
const c10::optional<at::Tensor>& D,
const c10::optional<at::Tensor>& z,
const c10::optional<at::Tensor>& delta_bias,
bool delta_softplus,
bool return_last_state) {
RECORD_FUNCTION("selective_scan_fn", c10::ArrayRef<c10::IValue>({}));
return selective_scan_kernel_stub(
kCPU,
u,
delta,
A,
B,
C,
D,
z,
delta_bias,
delta_softplus,
return_last_state);
}

/**
* Official Python Implementation:
* selective_state_update_ref:
* https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py#L219
* @param state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
* @param x: (batch, dim) or (batch, nheads, dim)
* @param dt: (batch, dim) or (batch, nheads, dim)
* @param A: (dim, dstate) or (nheads, dim, dstate) or (dstate, dim) or (nheads,
* dstate, dim)
* @param B: (batch, dstate) or (batch, ngroups, dstate)
* @param C: (batch, dstate) or (batch, ngroups, dstate)
* @param D: (dim,) or (nheads, dim) or None
* @param z: (batch, dim) or (batch, nheads, dim) or None
* @param dt_bias: (dim,) or (nheads, dim) or None
* @param dt_softplus: bool
* @return: out: (batch, dim) or (batch, nheads, dim)
*/
at::Tensor selective_state_update(
const at::Tensor& state,
const at::Tensor& x,
const at::Tensor& dt,
const at::Tensor& A,
const at::Tensor& B,
const at::Tensor& C,
const c10::optional<at::Tensor>& D,
const c10::optional<at::Tensor>& z,
const c10::optional<at::Tensor>& dt_bias,
bool dt_softplus) {
RECORD_FUNCTION("selective_state_update", c10::ArrayRef<c10::IValue>({}));
return selective_state_update_kernel_stub(
kCPU, state, x, dt, A, B, C, D, z, dt_bias, dt_softplus);
}

} // namespace cpu
} // namespace torch_ipex

namespace {

IPEX_TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
m.def(
"selective_scan_fn(Tensor u, Tensor delta, Tensor A, Tensor B, Tensor C, Tensor? D, Tensor? z, Tensor? delta_bias, bool delta_softplus, bool return_last_state) -> (Tensor, Tensor)");
m.impl(
"selective_scan_fn",
c10::DispatchKey::CPU,
torch_ipex::cpu::selective_scan);
m.def(
"selective_state_update(Tensor state, Tensor x, Tensor dt, Tensor A, Tensor B, Tensor C, Tensor? D, Tensor? z, Tensor? dt_bias, bool dt_softplus) -> (Tensor)");
m.impl(
"selective_state_update",
c10::DispatchKey::CPU,
torch_ipex::cpu::selective_state_update);
}

} // namespace
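For orientation, selective_scan_fn implements the discretized state-space recurrence from the Mamba paper cited in the comment above: x_t = exp(delta_t * A) * x_(t-1) + delta_t * B_t * u_t, followed by y_t = C_t · x_t (+ D * u_t), with an optional SiLU gate on z. The function below is a minimal reference sketch following the linked selective_scan_ref semantics, restricted to the (batch, dim, len) / (batch, dstate, len) layouts; it always returns the last state and omits the 2*len and ngroups shape variants, so treat it as an illustration rather than the dispatched kernel.

import torch
import torch.nn.functional as F

def selective_scan_sketch(u, delta, A, B, C, D=None, z=None,
                          delta_bias=None, delta_softplus=False):
    # u, delta: (batch, dim, len); A: (dim, dstate); B, C: (batch, dstate, len)
    if delta_bias is not None:
        delta = delta + delta_bias[None, :, None]
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, seqlen = u.shape
    dstate = A.shape[1]
    x = u.new_zeros(batch, dim, dstate)            # SSM hidden state
    ys = []
    for t in range(seqlen):
        # discretize: x_t = exp(delta_t * A) * x_{t-1} + delta_t * B_t * u_t
        dA = torch.exp(delta[:, :, t, None] * A)                      # (batch, dim, dstate)
        dBu = delta[:, :, t, None] * B[:, None, :, t] * u[:, :, t, None]
        x = dA * x + dBu
        ys.append((x * C[:, None, :, t]).sum(-1))                     # y_t = C_t . x_t
    out = torch.stack(ys, dim=-1)                                     # (batch, dim, len)
    if D is not None:
        out = out + u * D[None, :, None]                              # skip connection
    if z is not None:
        out = out * F.silu(z)                                         # gated output
    return out, x                                                     # (out, last_state)

The registered op itself is invoked as torch.ops.torch_ipex.selective_scan_fn(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state), matching the schema above.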
60 changes: 60 additions & 0 deletions csrc/cpu/aten/SelectiveScan.h
@@ -0,0 +1,60 @@
#pragma once

#include <ATen/Tensor.h>
#include <dyndisp/DispatchStub.h>

namespace torch_ipex {
namespace cpu {

std::tuple<at::Tensor, at::Tensor> selective_scan(
const at::Tensor& u,
const at::Tensor& delta,
const at::Tensor& A,
const at::Tensor& B,
const at::Tensor& C,
const c10::optional<at::Tensor>& D,
const c10::optional<at::Tensor>& z,
const c10::optional<at::Tensor>& delta_bias,
bool delta_softplus,
bool return_last_state);
at::Tensor selective_state_update(
const at::Tensor& state,
const at::Tensor& x,
const at::Tensor& dt,
const at::Tensor& A,
const at::Tensor& B,
const at::Tensor& C,
const c10::optional<at::Tensor>& D,
const c10::optional<at::Tensor>& z,
const c10::optional<at::Tensor>& dt_bias,
bool dt_softplus);

using selective_scan_kernel_fn = std::tuple<at::Tensor, at::Tensor> (*)(
const at::Tensor& u,
const at::Tensor& delta,
const at::Tensor& A,
const at::Tensor& B,
const at::Tensor& C,
const c10::optional<at::Tensor>& D,
const c10::optional<at::Tensor>& z,
const c10::optional<at::Tensor>& delta_bias,
bool delta_softplus,
bool return_last_state);
using selective_state_update_fn = at::Tensor (*)(
const at::Tensor& state,
const at::Tensor& x,
const at::Tensor& dt,
const at::Tensor& A,
const at::Tensor& B,
const at::Tensor& C,
const c10::optional<at::Tensor>& D,
const c10::optional<at::Tensor>& z,
const c10::optional<at::Tensor>& dt_bias,
bool dt_softplus);
IPEX_DECLARE_DISPATCH(selective_scan_kernel_fn, selective_scan_kernel_stub);
IPEX_DECLARE_DISPATCH(
selective_state_update_fn,
selective_state_update_kernel_stub);

} // namespace cpu
} // namespace torch_ipex
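The second stub declared here, selective_state_update, performs the same recurrence for a single decoding step, mutating the cached state. Below is a rough reference sketch of the documented semantics (following the selective_state_update_ref link in SelectiveScan.cpp), restricted to the (batch, dim, dstate) state layout and without the nheads/ngroups variants; it is an illustration under those assumptions, not the registered kernel.

import torch
import torch.nn.functional as F

def selective_state_update_sketch(state, x, dt, A, B, C, D=None, z=None,
                                  dt_bias=None, dt_softplus=False):
    # state: (batch, dim, dstate); x, dt: (batch, dim); A: (dim, dstate); B, C: (batch, dstate)
    if dt_bias is not None:
        dt = dt + dt_bias
    if dt_softplus:
        dt = F.softplus(dt)
    dA = torch.exp(dt[..., None] * A)              # (batch, dim, dstate)
    dBx = dt[..., None] * B[:, None, :] * x[..., None]
    state.copy_(state * dA + dBx)                  # advance the SSM state in place
    out = (state * C[:, None, :]).sum(-1)          # (batch, dim)
    if D is not None:
        out = out + D * x
    if z is not None:
        out = out * F.silu(z)
    return out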
89 changes: 89 additions & 0 deletions csrc/cpu/aten/kernels/CausalConvKrnl.cpp
@@ -0,0 +1,89 @@
#include <aten/Conv.h>
#include "mkl.h"
#include "vec/vec.h"

namespace torch_ipex {
namespace cpu {
namespace {
template <typename T>
std::tuple<at::Tensor, at::Tensor> causal_conv1d_update_kernel_inner(
const at::Tensor& hidden_states,
const at::Tensor& conv_states,
const at::Tensor& conv_weights,
const c10::optional<at::Tensor>& conv_bias,
bool silu_activation) {
auto bs = conv_states.size(0);
auto channels = conv_states.size(1);
auto kernel_size = conv_states.size(2);
auto has_bias = conv_bias.has_value();
auto bias_ptr = has_bias ? conv_bias.value().data_ptr<T>() : nullptr;
auto conv_states_ptr = conv_states.data_ptr<T>();
auto conv_weights_ptr = conv_weights.data_ptr<T>();
auto hidden_states_ptr = hidden_states.data_ptr<T>();
auto hidden_states_strideB = hidden_states.stride(0);
auto hidden_states_strideC = hidden_states.stride(1);
auto conv_states_strideB = conv_states.stride(0);
auto conv_states_strideC = conv_states.stride(1);
auto conv_states_strideK = conv_states.stride(2);
auto conv_weights_strideC = conv_weights.stride(0);
#pragma omp parallel for collapse(2)
for (auto bi = 0; bi < bs; bi++) {
for (auto ci = 0; ci < channels; ci++) {
auto conv_weights_start = ci * conv_weights_strideC;
float out = 0.0f;
auto conv_states_start =
bi * conv_states_strideB + ci * conv_states_strideC;
for (auto k = 1; k < kernel_size; k++) {
auto conv_states_idx = conv_states_start + k * conv_states_strideK;
out += conv_weights_ptr[conv_weights_start + k - 1] *
conv_states_ptr[conv_states_idx];
conv_states_ptr[conv_states_idx - conv_states_strideK] =
conv_states_ptr[conv_states_idx];
}
auto hidden_states_idx =
bi * hidden_states_strideB + ci * hidden_states_strideC;
out += hidden_states_ptr[hidden_states_idx] *
conv_weights_ptr[conv_weights_start + kernel_size - 1];
conv_states_ptr
[conv_states_start + (kernel_size - 1) * conv_states_strideK] =
hidden_states_ptr[hidden_states_idx];
if (has_bias) {
out += bias_ptr[ci];
}
if (silu_activation) {
out = out / (1 + expf(-out));
}
hidden_states_ptr[hidden_states_idx] = out;
}
}
return std::make_tuple(std::move(hidden_states), std::move(conv_states));
}

std::tuple<at::Tensor, at::Tensor> causal_conv1d_update_kernel_impl(
const at::Tensor& hidden_states,
const at::Tensor& conv_states,
const at::Tensor& conv_weights,
const c10::optional<at::Tensor>& conv_bias,
bool silu_activation) {
if (hidden_states.scalar_type() == at::ScalarType::Float) {
return causal_conv1d_update_kernel_inner<float>(
hidden_states, conv_states, conv_weights, conv_bias, silu_activation);
} else if (hidden_states.scalar_type() == at::ScalarType::BFloat16) {
return causal_conv1d_update_kernel_inner<at::BFloat16>(
hidden_states, conv_states, conv_weights, conv_bias, silu_activation);
} else if (hidden_states.scalar_type() == at::ScalarType::Half) {
return causal_conv1d_update_kernel_inner<at::Half>(
hidden_states, conv_states, conv_weights, conv_bias, silu_activation);
} else {
TORCH_CHECK(
false,
"Only support bfloat16, float16 and float for causal_conv1d_update");
}
}
} // anonymous namespace
IPEX_REGISTER_DISPATCH(
causal_conv1d_update_kernel_stub,
&causal_conv1d_update_kernel_impl);

} // namespace cpu
} // namespace torch_ipex
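To make the scalar loop above easier to follow, the same single-token update can be written with tensor ops: shift each channel's window left by one, append the incoming hidden state, take a depthwise dot product with the weights, then add the bias and apply SiLU. This is a readability sketch assuming the (batch, dim) single-token case with state_len equal to the filter width, not the dispatched implementation.

import torch
import torch.nn.functional as F

def causal_conv1d_update_sketch(hidden_states, conv_states, conv_weights,
                                conv_bias=None, silu_activation=False):
    # hidden_states: (batch, dim); conv_states: (batch, dim, width); conv_weights: (dim, width)
    # shift the rolling window left by one slot and append the incoming token
    conv_states.copy_(torch.cat(
        [conv_states[:, :, 1:], hidden_states.unsqueeze(-1)], dim=-1))
    out = (conv_states * conv_weights).sum(-1)     # depthwise dot product per channel
    if conv_bias is not None:
        out = out + conv_bias
    if silu_activation:
        out = F.silu(out)                          # same as out / (1 + exp(-out)) in the kernel
    hidden_states.copy_(out)                       # the kernel writes the result back in place
    return hidden_states, conv_states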