Introduce `MmuConfig`, the trait that ties the entry-operation traits (`PteOps`, `PdeOps`, `DualPdeOps`) together with the version-specific constants and helpers.
`MmuV2` and `MmuV3` are zero-sized marker structs that implement
`MmuConfig` for Turing/Ampere/Ada and Hopper/Blackwell respectively.
Dispatch is fully resolved at compile time through these markers, so
version-specific code is selected without runtime overhead and without
wrapper enums.

This enables version-agnostic page-table operations while keeping
version-specific implementation details encapsulated in the `ver2` and
`ver3` modules.

`pt_pages_upper_bound()` assumes worst-case alignment at every level, so
the returned bound holds for any starting virtual address rather than
only for ranges aligned to a page-table page.

Signed-off-by: Joel Fernandes <[email protected]>
---
 drivers/gpu/nova-core/mm/pagetable.rs | 118 ++++++++++++++++++++++++++
 1 file changed, 118 insertions(+)

diff --git a/drivers/gpu/nova-core/mm/pagetable.rs b/drivers/gpu/nova-core/mm/pagetable.rs
index 3cc546f94fdb..38f4f0c6e8ce 100644
--- a/drivers/gpu/nova-core/mm/pagetable.rs
+++ b/drivers/gpu/nova-core/mm/pagetable.rs
@@ -19,6 +19,7 @@
 use crate::mm::{
     pramin,
     Pfn,
+    VirtualAddress,
     VramAddress, //
 };
 
@@ -196,6 +197,123 @@ fn write(&self, window: &mut pramin::PraminWindow<'_>, addr: VramAddress) -> Res
     }
 }
 
+/// MMU configuration trait -- encodes version-specific constants and types.
+pub(super) trait MmuConfig: 'static {
+    /// Page Table Entry type.
+    type Pte: PteOps;
+    /// Page Directory Entry type.
+    type Pde: PdeOps;
+    /// Dual Page Directory Entry type (128-bit).
+    type DualPde: DualPdeOps;
+
+    /// PDE levels (excluding PTE level) for page table walking.
+    const PDE_LEVELS: &'static [PageTableLevel];
+    /// PTE level for this MMU version.
+    const PTE_LEVEL: PageTableLevel;
+    /// Dual PDE level (128-bit entries) for this MMU version.
+    const DUAL_PDE_LEVEL: PageTableLevel;
+
+    /// Get the number of entries per page table page for a given level.
+    fn entries_per_page(level: PageTableLevel) -> usize;
+
+    /// Extract the page table index at `level` from `va`.
+    fn level_index(va: VirtualAddress, level: u64) -> u64;
+
+    /// Get the entry size in bytes for a given level.
+    fn entry_size(level: PageTableLevel) -> usize {
+        if level == Self::DUAL_PDE_LEVEL {
+            16 // 128-bit dual PDE
+        } else {
+            8 // 64-bit PDE/PTE
+        }
+    }
+
+    /// Compute an upper bound on the page table pages needed to map
+    /// `num_virt_pages` contiguous virtual pages.
+    ///
+    /// Walks from the PTE level up through the PDE levels, accumulating the
+    /// tree. The range may start at any entry within a page table page, so at
+    /// every level a run of `n` entries can straddle up to one page more than
+    /// `n.div_ceil(entries_per_page)`; the bound assumes that worst case so it
+    /// holds regardless of the starting virtual address.
+    fn pt_pages_upper_bound(num_virt_pages: usize) -> usize {
+        // Maximum number of `epp`-entry pages touched by `n` contiguous
+        // entries starting at an arbitrary offset: in the worst case the
+        // first entry sits in the last slot of a page and the remaining
+        // `n - 1` entries spill into the following pages.
+        fn max_pages_spanned(n: usize, epp: usize) -> usize {
+            n.checked_sub(1).map_or(0, |rest| rest.div_ceil(epp) + 1)
+        }
+
+        // PTE pages at the leaf level.
+        let pte_epp = Self::entries_per_page(Self::PTE_LEVEL);
+        let mut pages_at_level = max_pages_spanned(num_virt_pages, pte_epp);
+        let mut total = pages_at_level;
+
+        // Walk PDE levels bottom-up (reverse of PDE_LEVELS): each level needs
+        // one entry per page allocated at the level below it.
+        for &level in Self::PDE_LEVELS.iter().rev() {
+            let epp = Self::entries_per_page(level);
+            pages_at_level = max_pages_spanned(pages_at_level, epp);
+            total += pages_at_level;
+        }
+
+        total
+    }
+}
+
+/// Marker struct for MMU v2 (Turing/Ampere/Ada).
+pub(super) struct MmuV2;
+
+impl MmuConfig for MmuV2 {
+    type Pte = ver2::Pte;
+    type Pde = ver2::Pde;
+    type DualPde = ver2::DualPde;
+
+    const PDE_LEVELS: &'static [PageTableLevel] = ver2::PDE_LEVELS;
+    const PTE_LEVEL: PageTableLevel = ver2::PTE_LEVEL;
+    const DUAL_PDE_LEVEL: PageTableLevel = ver2::DUAL_PDE_LEVEL;
+
+    fn entries_per_page(level: PageTableLevel) -> usize {
+        // TODO: Calculate these values from the bitfield dynamically
+        // instead of hardcoding them.
+        match level {
+            PageTableLevel::Pdb => 4, // PD3 root: bits [48:47] = 2 bits
+            PageTableLevel::L3 => 256, // PD0 dual: bits [28:21] = 8 bits
+            _ => 512, // PD2, PD1, PT: 9 bits each
+        }
+    }
+
+    fn level_index(va: VirtualAddress, level: u64) -> u64 {
+        ver2::VirtualAddressV2::new(va).level_index(level)
+    }
+}
+
+/// Marker struct for MMU v3 (Hopper and later).
+pub(super) struct MmuV3;
+
+impl MmuConfig for MmuV3 {
+    type Pte = ver3::Pte;
+    type Pde = ver3::Pde;
+    type DualPde = ver3::DualPde;
+
+    const PDE_LEVELS: &'static [PageTableLevel] = ver3::PDE_LEVELS;
+    const PTE_LEVEL: PageTableLevel = ver3::PTE_LEVEL;
+    const DUAL_PDE_LEVEL: PageTableLevel = ver3::DUAL_PDE_LEVEL;
+
+    fn entries_per_page(level: PageTableLevel) -> usize {
+        match level {
+            PageTableLevel::Pdb => 2, // PDE4 root: bit [56] = 1 bit, 2 entries
+            PageTableLevel::L4 => 256, // PDE0 dual: bits [28:21] = 8 bits
+            _ => 512, // PDE3, PDE2, PDE1, PT: 9 bits each
+        }
+    }
+
+    fn level_index(va: VirtualAddress, level: u64) -> u64 {
+        ver3::VirtualAddressV3::new(va).level_index(level)
+    }
+}
+
 /// Memory aperture for Page Table Entries (`PTE`s).
 ///
 /// Determines which memory region the `PTE` points to.
-- 
2.34.1
