diff --git a/linalg/Cargo.toml b/linalg/Cargo.toml index 74964e12fb..a120d58bbd 100644 --- a/linalg/Cargo.toml +++ b/linalg/Cargo.toml @@ -25,6 +25,7 @@ num-traits = "0.2.14" tract-data = { path = "../data" } paste = "1.0.5" scan_fmt = "0.2.6" +strength_reduce = "0.2" [build-dependencies] cc = "1.0.69" diff --git a/linalg/src/frame/pack.rs b/linalg/src/frame/pack.rs index 74bed74cae..086008b984 100644 --- a/linalg/src/frame/pack.rs +++ b/linalg/src/frame/pack.rs @@ -1,19 +1,23 @@ use std::fmt::Debug; use std::marker::PhantomData; use std::ops::Range; +use strength_reduce::StrengthReducedUsize; use tract_data::internal::*; -#[derive(Clone, Debug, Eq, PartialEq, Educe)] -#[educe(Hash)] +#[derive(Clone, Debug, Educe)] +#[educe(Hash, PartialEq, Eq)] pub struct Packer { pub r: usize, alignment: usize, end_padding_record: usize, + #[educe(PartialEq(ignore))] + #[educe(Hash(ignore))] + r_reduced: StrengthReducedUsize, } impl Packer { pub fn new(nr: usize, alignment: usize, end_padding_record: usize) -> Packer { - Packer { r: nr, alignment, end_padding_record } + Packer { r: nr, alignment, end_padding_record, r_reduced: StrengthReducedUsize::new(nr) } } pub fn alignment(&self) -> usize { @@ -126,7 +130,7 @@ impl Packer { k: usize, mn: usize, ) -> KOutWriter<'p, T> { - KOutWriter::new(pb, self.r, mn, k) + KOutWriter::new(pb, self.r_reduced, mn, k) } pub fn write_single_panel_with_k_outer<'p, T: Copy + Debug>( @@ -142,7 +146,7 @@ impl Packer { k: usize, mn: usize, ) -> KInWriter<'p, T> { - KInWriter::new(pb, self.r, mn, k) + KInWriter::new(pb, self.r_reduced, mn, k) } } @@ -201,19 +205,25 @@ impl<'p, T> KOutWriter<'p, T> where T: Copy + std::fmt::Debug, { - pub fn new(ptr: *mut T, panel_width: usize, mn: usize, k: usize) -> KOutWriter<'p, T> { - let panels = (mn + panel_width - 1) / panel_width; - let last_panel_width = mn - (panels - 1) * panel_width; + pub fn new( + ptr: *mut T, + panel_width_reduced: StrengthReducedUsize, + mn: usize, + k: usize, + ) -> KOutWriter<'p, T> { + let (div, rem) = StrengthReducedUsize::div_rem(mn, panel_width_reduced); + let (panels, last_panel_width) = + if rem > 0 { (div + 1, rem) } else { (div, panel_width_reduced.get()) }; KOutWriter { ptr, panels, - panel_width, + panel_width: panel_width_reduced.get(), last_panel_width, - remain: if panels > 1 { panel_width } else { last_panel_width }, + remain: if panels > 1 { panel_width_reduced.get() } else { last_panel_width }, current_panel: 0, - next_panel: ((k - 1) * panel_width) as isize, - next_lane: panel_width as isize - - ((last_panel_width + (panels - 1) * panel_width * k) as isize), + next_panel: ((k - 1) * panel_width_reduced.get()) as isize, + next_lane: panel_width_reduced.get() as isize + - ((last_panel_width + (panels - 1) * panel_width_reduced.get() * k) as isize), _phantom: PhantomData, } } @@ -269,20 +279,26 @@ impl<'p, T> KInWriter<'p, T> where T: Copy + Debug, { - pub fn new(ptr: *mut T, panel_width: usize, mn: usize, k: usize) -> KInWriter<'p, T> { - let panels = (mn + panel_width - 1) / panel_width; - let last_panel_width = mn - (panels - 1) * panel_width; + pub fn new( + ptr: *mut T, + panel_width_reduced: StrengthReducedUsize, + mn: usize, + k: usize, + ) -> KInWriter<'p, T> { + let (div, rem) = StrengthReducedUsize::div_rem(mn, panel_width_reduced); + let (panels, last_panel_width) = + if rem > 0 { (div + 1, rem) } else { (div, panel_width_reduced.get()) }; KInWriter { ptr, k, panels, - panel_width, + panel_width: panel_width_reduced.get(), last_panel_width, remain_on_k: k, - remain_on_mn: if panels == 1 { last_panel_width } else { panel_width }, + remain_on_mn: if panels == 1 { last_panel_width } else { panel_width_reduced.get() }, current_panel: 0, - next_mn_offset: 1 - (k * panel_width) as isize, - next_panel_offset: 1 - panel_width as isize, + next_mn_offset: 1 - (k * panel_width_reduced.get()) as isize, + next_panel_offset: 1 - panel_width_reduced.get() as isize, _phantom: PhantomData, } } @@ -325,7 +341,7 @@ pub unsafe fn pack_mn_major( mn_range_bytes: Range, k_range: Range, ) { - let mnr = std::mem::size_of::(); + let mnr:usize = std::mem::size_of::(); let full_panes = mn_range_bytes.len() / mnr; let partial_pane = mn_range_bytes.len() % mnr; for k in 0..k_range.len() {