Posted to github@arrow.apache.org by "rok (via GitHub)" <gi...@apache.org> on 2023/04/06 18:48:57 UTC

[GitHub] [arrow] rok commented on pull request #34797: GH-34796: [C++] Add FromTensor, ToTensor and strides methods to FixedShapeTensorArray

rok commented on PR #34797:
URL: https://github.com/apache/arrow/pull/34797#issuecomment-1499479421

   A change introduced late in #8510 (and later removed) replaced `ComputeRowMajorStrides` with a general `ComputeStrides` that calculates the strides of permuted tensors. That should help with matching PyTorch's behavior:
   
   ```python
   import torch
   
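   def print_strides(shape, permutation):
       # Assumes permutation[0] == 0, i.e. the leading (batch) axis stays first;
       # the cell tensor is the full tensor without that axis.
       permutation2 = [p - 1 for p in permutation[1:]]
       shape2 = shape[1:]
       # Use float64 so each element is 8 bytes wide.
       x = torch.randn(shape, dtype=torch.float64).permute(permutation)
       y = torch.randn(shape2, dtype=torch.float64).permute(permutation2)
       # torch reports strides in elements; convert them to byte strides.
       strides = [z * x.element_size() for z in x.stride()]
       strides2 = [z * y.element_size() for z in y.stride()]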
       
       print("[full tensor] shape:", shape, "permutation:", permutation, "strides:", strides)
       print("[cell tensor] shape:", shape2, "permutation:", permutation2, "strides:", strides2)
   
   shapes = [
       (3, 3, 4),
       (3, 3, 4),
       (3, 4, 3),
       (3, 4, 3),
   ]
   permutations = [
       (0, 1, 2),
       (0, 2, 1),
       (0, 1, 2),
       (0, 2, 1),
   ]
   
   for shape, permutation in zip(shapes, permutations):
       print_strides(shape, permutation)
   ```
   
   ```
   [full tensor] shape: (3, 3, 4) permutation: (0, 1, 2) strides: [96, 32, 8]
   [cell tensor] shape: (3, 4) permutation: [0, 1] strides: [32, 8]
   [full tensor] shape: (3, 3, 4) permutation: (0, 2, 1) strides: [96, 8, 32]
   [cell tensor] shape: (3, 4) permutation: [1, 0] strides: [8, 32]
   [full tensor] shape: (3, 4, 3) permutation: (0, 1, 2) strides: [96, 24, 8]
   [cell tensor] shape: (4, 3) permutation: [0, 1] strides: [24, 8]
   [full tensor] shape: (3, 4, 3) permutation: (0, 2, 1) strides: [96, 8, 24]
   [cell tensor] shape: (4, 3) permutation: [1, 0] strides: [8, 24]
   ```
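   The stride values above can also be reproduced without PyTorch. This is a minimal sketch, not Arrow's `ComputeStrides`: `row_major_strides` and `permuted_strides` are hypothetical helpers, and like the script above it assumes 8-byte elements and a permutation that keeps the batch axis first. A permuted view keeps the row-major strides of the original contiguous tensor, reordered by the permutation:

   ```python
   from typing import List, Sequence


   def row_major_strides(shape: Sequence[int], elem_size: int) -> List[int]:
       """Byte strides of a contiguous (C-order, row-major) tensor."""
       strides = []
       stride = elem_size
       for dim in reversed(shape):
           strides.append(stride)
           stride *= dim
       return strides[::-1]


   def permuted_strides(shape: Sequence[int], permutation: Sequence[int],
                        elem_size: int = 8) -> List[int]:
       """Byte strides of the view `contiguous(shape).permute(permutation)`.

       Axis i of the permuted view is axis permutation[i] of the original
       contiguous tensor, so it keeps that axis's row-major stride.
       """
       base = row_major_strides(shape, elem_size)
       return [base[p] for p in permutation]


   cases = [
       ((3, 3, 4), (0, 1, 2)),
       ((3, 3, 4), (0, 2, 1)),
       ((3, 4, 3), (0, 1, 2)),
       ((3, 4, 3), (0, 2, 1)),
   ]

   for shape, permutation in cases:
       full = permuted_strides(shape, permutation)
       # Cell tensor: drop the leading (batch) axis, which the permutation keeps first.
       cell_perm = [p - 1 for p in permutation[1:]]
       cell = permuted_strides(shape[1:], cell_perm)
       # With the batch axis left in place, the cell strides are the trailing full strides.
       assert cell == full[1:]
       print(shape, permutation, full, cell)
   ```

   For these four cases it yields the same full-tensor and cell-tensor strides as the PyTorch output above, and the assertion makes explicit that each cell tensor's strides are the full tensor's strides without the batch axis.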

