Skip to content

Commit

Permalink
Merge pull request #283 from cbourjau/fix-default-complex
Browse files Browse the repository at this point in the history
Fix way to determine default_complex
  • Loading branch information
asmeurer authored Aug 23, 2024
2 parents db95e67 + 4da61e5 commit 4caff28
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def accumulation_result_dtype(x_dtype, dtype_kwarg):
default_float = xp.asarray(float()).dtype
if default_float not in real_float_dtypes:
warn(f"inferred default float is {default_float!r}, which is not a float")
if api_version > "2021.12":
if api_version > "2021.12" and ({'complex64', 'complex128'} - set(skip_dtypes)):
default_complex = xp.asarray(complex()).dtype
if default_complex not in complex_dtypes:
warn(
Expand Down
4 changes: 2 additions & 2 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,11 @@ def oneway_broadcastable_shapes(draw) -> OnewayBroadcastableShapes:
real_floating_dtypes = sampled_from(dh.real_float_dtypes)
numeric_dtypes = sampled_from(dh.numeric_dtypes)
# Note: this always returns complex dtypes, even if api_version < 2022.12
complex_dtypes = sampled_from(dh.complex_dtypes)
complex_dtypes: SearchStrategy[Any] | None = sampled_from(dh.complex_dtypes) if dh.complex_dtypes else None

def all_floating_dtypes() -> SearchStrategy[DataType]:
strat = floating_dtypes
if api_version >= "2022.12":
if api_version >= "2022.12" and complex_dtypes is not None:
strat |= complex_dtypes
return strat

Expand Down

0 comments on commit 4caff28

Please sign in to comment.