Add support for ROCm 7.0 and 7.1 (#17681)

## Summary

Closes #17674.
This commit is contained in:
Charlie Marsh
2026-01-24 10:54:03 -05:00
committed by GitHub
parent 73410fad33
commit d2b0f1dc23
3 changed files with 78 additions and 1 deletions
+3
View File
@@ -296,6 +296,7 @@ pub enum AmdGpuArchitecture {
Gfx908,
Gfx90a,
Gfx942,
Gfx950,
Gfx1030,
Gfx1100,
Gfx1101,
@@ -314,6 +315,7 @@ impl FromStr for AmdGpuArchitecture {
"gfx908" => Ok(Self::Gfx908),
"gfx90a" => Ok(Self::Gfx90a),
"gfx942" => Ok(Self::Gfx942),
"gfx950" => Ok(Self::Gfx950),
"gfx1030" => Ok(Self::Gfx1030),
"gfx1100" => Ok(Self::Gfx1100),
"gfx1101" => Ok(Self::Gfx1101),
@@ -333,6 +335,7 @@ impl std::fmt::Display for AmdGpuArchitecture {
Self::Gfx908 => write!(f, "gfx908"),
Self::Gfx90a => write!(f, "gfx90a"),
Self::Gfx942 => write!(f, "gfx942"),
Self::Gfx950 => write!(f, "gfx950"),
Self::Gfx1030 => write!(f, "gfx1030"),
Self::Gfx1100 => write!(f, "gfx1100"),
Self::Gfx1101 => write!(f, "gfx1101"),
+65 -1
View File
@@ -113,6 +113,14 @@ pub enum TorchMode {
Cu90,
/// Use the PyTorch index for CUDA 8.0.
Cu80,
/// Use the PyTorch index for ROCm 7.1.
#[serde(rename = "rocm7.1")]
#[cfg_attr(feature = "clap", clap(name = "rocm7.1"))]
Rocm71,
/// Use the PyTorch index for ROCm 7.0.
#[serde(rename = "rocm7.0")]
#[cfg_attr(feature = "clap", clap(name = "rocm7.0"))]
Rocm70,
/// Use the PyTorch index for ROCm 6.4.
#[serde(rename = "rocm6.4")]
#[cfg_attr(feature = "clap", clap(name = "rocm6.4"))]
@@ -276,6 +284,8 @@ impl TorchStrategy {
TorchMode::Cu91 => TorchBackend::Cu91,
TorchMode::Cu90 => TorchBackend::Cu90,
TorchMode::Cu80 => TorchBackend::Cu80,
TorchMode::Rocm71 => TorchBackend::Rocm71,
TorchMode::Rocm70 => TorchBackend::Rocm70,
TorchMode::Rocm64 => TorchBackend::Rocm64,
TorchMode::Rocm63 => TorchBackend::Rocm63,
TorchMode::Rocm624 => TorchBackend::Rocm624,
@@ -527,6 +537,8 @@ pub enum TorchBackend {
Cu91,
Cu90,
Cu80,
Rocm71,
Rocm70,
Rocm64,
Rocm63,
Rocm624,
@@ -659,6 +671,14 @@ impl TorchBackend {
TorchSource::PyTorch => &PYTORCH_CU80_INDEX_URL,
TorchSource::Pyx => &PYX_CU80_INDEX_URL,
},
Self::Rocm71 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM71_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM71_INDEX_URL,
},
Self::Rocm70 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM70_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM70_INDEX_URL,
},
Self::Rocm64 => match source {
TorchSource::PyTorch => &PYTORCH_ROCM64_INDEX_URL,
TorchSource::Pyx => &PYX_ROCM64_INDEX_URL,
@@ -790,6 +810,8 @@ impl TorchBackend {
Self::Cu91 => Some(Version::new([9, 1])),
Self::Cu90 => Some(Version::new([9, 0])),
Self::Cu80 => Some(Version::new([8, 0])),
Self::Rocm71 => None,
Self::Rocm70 => None,
Self::Rocm64 => None,
Self::Rocm63 => None,
Self::Rocm624 => None,
@@ -841,6 +863,8 @@ impl TorchBackend {
Self::Cu91 => None,
Self::Cu90 => None,
Self::Cu80 => None,
Self::Rocm71 => Some(Version::new([7, 1])),
Self::Rocm70 => Some(Version::new([7, 0])),
Self::Rocm64 => Some(Version::new([6, 4])),
Self::Rocm63 => Some(Version::new([6, 3])),
Self::Rocm624 => Some(Version::new([6, 2, 4])),
@@ -895,6 +919,8 @@ impl FromStr for TorchBackend {
"cu91" => Ok(Self::Cu91),
"cu90" => Ok(Self::Cu90),
"cu80" => Ok(Self::Cu80),
"rocm7.1" => Ok(Self::Rocm71),
"rocm7.0" => Ok(Self::Rocm70),
"rocm6.4" => Ok(Self::Rocm64),
"rocm6.3" => Ok(Self::Rocm63),
"rocm6.2.4" => Ok(Self::Rocm624),
@@ -1010,9 +1036,35 @@ static WINDOWS_CUDA_VERSIONS: LazyLock<[(TorchBackend, Version); 26]> = LazyLock
///
/// AMD also provides a compatibility matrix: <https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html>;
/// however, this list includes a broader array of GPUs than those in the matrix.
static LINUX_AMD_GPU_DRIVERS: LazyLock<[(TorchBackend, AmdGpuArchitecture); 55]> =
static LINUX_AMD_GPU_DRIVERS: LazyLock<[(TorchBackend, AmdGpuArchitecture); 79]> =
LazyLock::new(|| {
[
// ROCm 7.1
(TorchBackend::Rocm71, AmdGpuArchitecture::Gfx900),
(TorchBackend::Rocm71, AmdGpuArchitecture::Gfx906),
(TorchBackend::Rocm71, AmdGpuArchitecture::Gfx908),
(TorchBackend::Rocm71, AmdGpuArchitecture::Gfx90a),
(TorchBackend::Rocm71, AmdGpuArchitecture::Gfx942),
(TorchBackend::Rocm71, AmdGpuArchitecture::Gfx950),
(TorchBackend::Rocm71, AmdGpuArchitecture::Gfx1030),
(TorchBackend::Rocm71, AmdGpuArchitecture::Gfx1100),
(TorchBackend::Rocm71, AmdGpuArchitecture::Gfx1101),
(TorchBackend::Rocm71, AmdGpuArchitecture::Gfx1102),
(TorchBackend::Rocm71, AmdGpuArchitecture::Gfx1200),
(TorchBackend::Rocm71, AmdGpuArchitecture::Gfx1201),
// ROCm 7.0
(TorchBackend::Rocm70, AmdGpuArchitecture::Gfx900),
(TorchBackend::Rocm70, AmdGpuArchitecture::Gfx906),
(TorchBackend::Rocm70, AmdGpuArchitecture::Gfx908),
(TorchBackend::Rocm70, AmdGpuArchitecture::Gfx90a),
(TorchBackend::Rocm70, AmdGpuArchitecture::Gfx942),
(TorchBackend::Rocm70, AmdGpuArchitecture::Gfx950),
(TorchBackend::Rocm70, AmdGpuArchitecture::Gfx1030),
(TorchBackend::Rocm70, AmdGpuArchitecture::Gfx1100),
(TorchBackend::Rocm70, AmdGpuArchitecture::Gfx1101),
(TorchBackend::Rocm70, AmdGpuArchitecture::Gfx1102),
(TorchBackend::Rocm70, AmdGpuArchitecture::Gfx1200),
(TorchBackend::Rocm70, AmdGpuArchitecture::Gfx1201),
// ROCm 6.4
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx900),
(TorchBackend::Rocm64, AmdGpuArchitecture::Gfx906),
@@ -1131,6 +1183,10 @@ static PYTORCH_CU90_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu90").unwrap());
static PYTORCH_CU80_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/cu80").unwrap());
static PYTORCH_ROCM71_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm7.1").unwrap());
static PYTORCH_ROCM70_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm7.0").unwrap());
static PYTORCH_ROCM64_INDEX_URL: LazyLock<IndexUrl> =
LazyLock::new(|| IndexUrl::from_str("https://download.pytorch.org/whl/rocm6.4").unwrap());
static PYTORCH_ROCM63_INDEX_URL: LazyLock<IndexUrl> =
@@ -1281,6 +1337,14 @@ static PYX_CU80_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*PYX_API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/cu80")).unwrap()
});
static PYX_ROCM71_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*PYX_API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm7.1")).unwrap()
});
static PYX_ROCM70_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*PYX_API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm7.0")).unwrap()
});
static PYX_ROCM64_INDEX_URL: LazyLock<IndexUrl> = LazyLock::new(|| {
let api_base_url = &*PYX_API_BASE_URL;
IndexUrl::from_str(&format!("{api_base_url}/simple/astral-sh/rocm6.4")).unwrap()
+10
View File
@@ -2397,6 +2397,16 @@
"type": "string",
"const": "cu80"
},
{
"description": "Use the PyTorch index for ROCm 7.1.",
"type": "string",
"const": "rocm7.1"
},
{
"description": "Use the PyTorch index for ROCm 7.0.",
"type": "string",
"const": "rocm7.0"
},
{
"description": "Use the PyTorch index for ROCm 6.4.",
"type": "string",