Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup ops/transformer/inference tests #6830

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
2 changes: 0 additions & 2 deletions tests/unit/ops/transformer/inference/test_bias_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)

torch_minor_version = None


def run_bias_add_reference(activations, bias):
return activations + bias
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/ops/transformer/inference/test_bias_gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.ops.transformer import DeepSpeedInferenceConfig
from deepspeed.ops.transformer.inference.op_binding.bias_gelu import BiasGeluOp
from deepspeed.utils.torch import required_torch_version
from .inference_test_utils import allclose, get_dtypes
from packaging import version as pkg_version

if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)
Expand All @@ -34,7 +34,7 @@ def run_bias_gelu_ds(activations, bias):
@pytest.mark.parametrize("channels", [512, 1232, 4096])
@pytest.mark.parametrize("dtype", get_dtypes())
def test_bias_gelu(batch, sequence, channels, dtype):
if pkg_version.parse(torch.__version__) < pkg_version.parse("1.12"):
if not required_torch_version(min_version=1.12):
pytest.skip("gelu implementation matches only after torch 1.12")

activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=get_accelerator().device_name())
Expand Down
2 changes: 0 additions & 2 deletions tests/unit/ops/transformer/inference/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)

inference_module = None


def allclose(x, y):
assert x.dtype == y.dtype
Expand Down
Loading