[torchlib] Fix linspace implementation for int64 #2693
Conversation
Who can review: @justinchuby
Codecov Report ❌ Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff             @@
##             main    #2693      +/-   ##
==========================================
- Coverage   70.11%   70.08%   -0.04%
==========================================
  Files         226      226
  Lines       27230    27282      +52
  Branches     2748     2755       +7
==========================================
+ Hits        19092    19120      +28
- Misses       7193     7213      +20
- Partials      945      949       +4
```

☔ View full report in Codecov by Sentry.
Thanks. Could you unskip the tests:
Thank you for reviewing! Done.
Hi @justinchuby, are there any updates on this that you could share? Thank you!
Hi @justinchuby, the CI is failing and I've updated the branch. Could you approve the workflows for testing again? Thank you!
Hi @justinchuby, the CUDA tests fail because PyTorch itself gives different results on CPU vs CUDA for integer linspace. For example:
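A minimal repro sketch along these lines (assuming a CUDA-enabled PyTorch build; which exact values diverge depends on the PyTorch version and hardware):

```python
import torch

# Integer linspace goes through floating-point intermediates, and the CPU
# and CUDA backends round those intermediates differently, so the truncated
# int64 results can disagree between devices.
cpu = torch.linspace(4.3, -3, 50, dtype=torch.int64, device="cpu")
if torch.cuda.is_available():
    cuda = torch.linspace(4.3, -3, 50, dtype=torch.int64, device="cuda")
    print(torch.equal(cpu, cuda.cpu()))  # may print False on some builds
```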
Thanks. In CI we only run CPU tests, so we should be OK.
Got it. Thank you. Could you please approve the tests?
Thank you for the approval. Let me know if anything else is needed from my side.
I just realized you are using double precision. Does float32 work? Or is float64 required?
I tested both precisions, and float64 is required for correctness:

```
# Index 21: float32 gives 0.999999761... → truncates to 0
#           float64 gives 1.000000000000... → truncates to 1
```
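To isolate the arithmetic, here is a small sketch using the index-21 value from the `test_precision(4.3, -3, 50, ...)` case below, where the step is `(-3 - 4) / 49 = -1/7`:

```python
import numpy as np

# The forward value at index 21 should be exactly 4 + 21 * (-1/7) = 1,
# but float32 rounding of the step lands the result just below 1.
step = (-3 - 4) / 49
val_f32 = np.float32(4) + np.float32(step) * np.float32(21)
val_f64 = np.float64(4) + np.float64(step) * np.float64(21)
print(val_f32)  # ~0.99999976 -> astype(np.int64) truncates to 0
print(val_f64)  # 1.0         -> astype(np.int64) truncates to 1
```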
This is the code I used to test:

```python
import torch
import numpy as np

def test_precision(start, end, steps, dtype_name):
    print(f"\n{'=' * 60}")
    print(f"linspace({start}, {end}, {steps}) with {dtype_name}")
    print(f"{'=' * 60}")

    start_int, end_int = int(start), int(end)

    # Emulate the linspace computation in float32
    step_f32 = np.float32((end_int - start_int) / (steps - 1))
    indices_f32 = np.arange(steps, dtype=np.float32)
    forward_f32 = start_int + step_f32 * indices_f32
    backward_f32 = end_int - step_f32 * (steps - 1 - indices_f32)
    result_f32 = np.where(indices_f32 < steps / 2, forward_f32, backward_f32).astype(np.int64)

    # Emulate the same computation in float64
    step_f64 = np.float64((end_int - start_int) / (steps - 1))
    indices_f64 = np.arange(steps, dtype=np.float64)
    forward_f64 = start_int + step_f64 * indices_f64
    backward_f64 = end_int - step_f64 * (steps - 1 - indices_f64)
    result_f64 = np.where(indices_f64 < steps / 2, forward_f64, backward_f64).astype(np.int64)

    # Reference: PyTorch's own integer linspace
    torch_result = torch.linspace(start, end, steps, dtype=torch.int64).numpy()

    match_f32 = np.array_equal(result_f32, torch_result)
    match_f64 = np.array_equal(result_f64, torch_result)
    print(f"Float32 matches PyTorch: {match_f32}")
    print(f"Float64 matches PyTorch: {match_f64}")

    if not match_f32:
        diff_indices = np.where(result_f32 != torch_result)[0]
        print(f"\nFloat32 differences at {len(diff_indices)} indices: {diff_indices[:10]}")
        for idx in diff_indices[:3]:
            print(f"  Index {idx}: f32={result_f32[idx]}, f64={result_f64[idx]}, pytorch={torch_result[idx]}")
            print(f"    f32_float={forward_f32[idx] if idx < steps / 2 else backward_f32[idx]:.15f}")
            print(f"    f64_float={forward_f64[idx] if idx < steps / 2 else backward_f64[idx]:.15f}")

test_precision(4.3, -3, 50, "int64")
test_precision(0, 7, 50, "int64")
test_precision(50, 0, 50, "int64")
```
Thanks
Uh oh... looks like we're back to square one 😄. Please let me know what to do, @justinchuby.
Thanks, I will revert some of my changes.
Got it. Thank you.
Description
Fixes #854: `linspace` now correctly handles the `int64` dtype.
Changes
- Updated `aten_linspace` to compute in floating point, then cast to the target dtype (see the sketch below)
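A minimal sketch of the idea (a hypothetical NumPy version for illustration, not the actual `aten_linspace` code; `linspace_int64` is an invented name):

```python
import numpy as np

def linspace_int64(start, end, steps):
    # Sketch: compute the points in float64, then truncate-cast to int64,
    # mirroring how torch.linspace produces integer outputs.
    if steps == 1:
        return np.array([int(start)], dtype=np.int64)
    step = (int(end) - int(start)) / (steps - 1)  # float64 step size
    points = int(start) + step * np.arange(steps, dtype=np.float64)
    return points.astype(np.int64)  # cast truncates toward zero

print(linspace_int64(0, 10, 5))  # [ 0  2  5  7 10]
```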
Testing
Manually verified:
- `linspace(0, 10, 5, dtype=int64)` now produces the correct output `[0, 2, 5, 7, 10]` (quick check below)
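A quick check of that output (the float points are `[0, 2.5, 5, 7.5, 10]`, and the int64 cast truncates):

```python
import torch

print(torch.linspace(0, 10, 5, dtype=torch.int64))
# tensor([ 0,  2,  5,  7, 10])
```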
Questions
Where should I add automated test cases for this fix? Happy to add tests wherever you suggest!