Conversation

@Aravind-11

Description

Fixes #854 - linspace now correctly handles int64 dtype

Changes

  • Modified aten_linspace to compute in floating-point then cast to target dtype
  • This matches PyTorch's behavior and fixes integer division precision loss
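
In case it helps review, the approach can be sketched in NumPy (illustrative only — `linspace_int64` is a hypothetical helper, not the actual `aten_linspace` implementation):

```python
import numpy as np

def linspace_int64(start, end, steps):
    # Pure integer arithmetic would floor the step ((10 - 0) // 4 == 2) and
    # accumulate the error across the range, giving [0, 2, 4, 6, 8].
    # Computing in float64 first keeps the fractional step (2.5), and the
    # final cast truncates each value toward zero.
    values = np.linspace(np.float64(start), np.float64(end), steps)
    return values.astype(np.int64)

print(linspace_int64(0, 10, 5))  # [ 0  2  5  7 10]
```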

Testing

Manually verified: linspace(0, 10, 5, dtype=int64) now produces correct output [0, 2, 5, 7, 10]

Questions

Where should I add automated test cases for this fix? Happy to add tests wherever you suggest!

Who can review : @justinchuby

@codecov

codecov bot commented Nov 16, 2025

Codecov Report

❌ Patch coverage is 72.72727% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.08%. Comparing base (5583f96) to head (beacf26).
⚠️ Report is 2 commits behind head on main.

Files with missing lines | Patch % | Lines
onnxscript/function_libs/torch_lib/ops/core.py | 72.72% | 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2693      +/-   ##
==========================================
- Coverage   70.11%   70.08%   -0.04%     
==========================================
  Files         226      226              
  Lines       27230    27282      +52     
  Branches     2748     2755       +7     
==========================================
+ Hits        19092    19120      +28     
- Misses       7193     7213      +20     
- Partials      945      949       +4     


@justinchuby
Collaborator

Thanks. Could you unskip the tests:

@Aravind-11
Author

> Thanks. Could you unskip the tests:

Thank you for reviewing! Done.

Hi @justinchuby, are there any updates on this that you could share? Thank you!

github-project-automation bot moved this from Todo to Done in ONNX Script Review Board Dec 9, 2025
@Aravind-11
Author

Aravind-11 commented Dec 9, 2025

Hi @justinchuby, the CI is failing and I have updated the branch. Could you approve the workflows for testing again? Thank you!

@Aravind-11
Author

Hi @justinchuby , CUDA tests fail because PyTorch itself gives different results on CPU vs CUDA for integer linspace. For example, torch.linspace(4.3, -3, 50, dtype=torch.int64) returns different values at certain indices on CPU vs CUDA.

@justinchuby
Collaborator

Thanks. In CI we only run CPU tests, so we should be OK.

@Aravind-11
Author

> Thanks. In CI we only run CPU tests, so we should be OK.

Got it, thank you. Could you please approve the tests?

@justinchuby justinchuby changed the title Fixes #854 [torchlib] Fix linspace implementation for int64 Dec 11, 2025
@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Dec 11, 2025
@justinchuby justinchuby self-assigned this Dec 11, 2025
@Aravind-11
Author

Thank you for the approval. Let me know if anything else is needed from my side.

@justinchuby
Collaborator

I just realized you are using double precision. Does float32 work? Or is float64 required?

@Aravind-11
Author

I tested both precisions, and float64 is required for correctness.
With float32, we get precision errors at certain indices:

# Index 21: float32 gives 0.999999761... → truncates to 0
#           float64 gives 1.000000000000... → truncates to 1 
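
The index-21 case can be reproduced with plain NumPy scalars (a standalone illustration of the rounding difference, using the truncated endpoints 4 and -3 and the forward formula start + step * i from the test script below):

```python
import numpy as np

# linspace(4, -3, 50): step = -7/49 = -1/7, so index 21 lands exactly on
# 4 - 21/7 = 1 in real arithmetic. The rounded step differs by precision:
step32 = np.float32((-3 - 4) / 49)
step64 = np.float64((-3 - 4) / 49)

v32 = np.float32(4) + step32 * np.float32(21)
v64 = np.float64(4) + step64 * np.float64(21)

print(v32)  # just below 1.0, so int() truncates to 0
print(v64)  # 1.0 to 15 decimal places, so int() truncates to 1
print(int(v32), int(v64))  # 0 1
```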

@Aravind-11
Author

This is the code I used to test:

import torch
import numpy as np

def test_precision(start, end, steps, dtype_name):
    print(f"\n{'='*60}")
    print(f"linspace({start}, {end}, {steps}) with {dtype_name}")
    print(f"{'='*60}")

    start_int, end_int = int(start), int(end)

    # Float32 path: compute the step, a forward sweep from start and a
    # backward sweep from end, pick per index, then cast to int64.
    step_f32 = np.float32((end_int - start_int) / (steps - 1))
    indices_f32 = np.arange(steps, dtype=np.float32)
    forward_f32 = start_int + step_f32 * indices_f32
    backward_f32 = end_int - step_f32 * (steps - 1 - indices_f32)
    result_f32 = np.where(indices_f32 < steps / 2, forward_f32, backward_f32).astype(np.int64)

    # Float64 path: identical arithmetic at double precision.
    step_f64 = np.float64((end_int - start_int) / (steps - 1))
    indices_f64 = np.arange(steps, dtype=np.float64)
    forward_f64 = start_int + step_f64 * indices_f64
    backward_f64 = end_int - step_f64 * (steps - 1 - indices_f64)
    result_f64 = np.where(indices_f64 < steps / 2, forward_f64, backward_f64).astype(np.int64)

    # Reference: PyTorch's own integer linspace.
    torch_result = torch.linspace(start, end, steps, dtype=torch.int64).numpy()

    match_f32 = np.array_equal(result_f32, torch_result)
    match_f64 = np.array_equal(result_f64, torch_result)

    print(f"Float32 matches PyTorch: {match_f32}")
    print(f"Float64 matches PyTorch: {match_f64}")

    if not match_f32:
        diff_indices = np.where(result_f32 != torch_result)[0]
        print(f"\nFloat32 differences at {len(diff_indices)} indices: {diff_indices[:10]}")
        for idx in diff_indices[:3]:
            print(f"  Index {idx}: f32={result_f32[idx]}, f64={result_f64[idx]}, pytorch={torch_result[idx]}")
            print(f"    f32_float={forward_f32[idx] if idx < steps / 2 else backward_f32[idx]:.15f}")
            print(f"    f64_float={forward_f64[idx] if idx < steps / 2 else backward_f64[idx]:.15f}")

test_precision(4.3, -3, 50, "int64")
test_precision(0, 7, 50, "int64")
test_precision(50, 0, 50, "int64")

@justinchuby
Collaborator

Thanks

@Aravind-11
Author

Uh oh... looks like we're back to square one 😄. Please let me know what to do, @justinchuby.

@justinchuby
Collaborator

Thanks, I will revert some of my changes

@Aravind-11
Author

Got it. Thank you.


Successfully merging this pull request may close these issues.

[torchlib] linspace results do not match with PyTorch when dtype is int64
