Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/04/30 03:31:53 UTC

[GitHub] [incubator-tvm] spectrometerHBH opened a new pull request #5483: [TIR][Printer] text format printer considering future parsing use

spectrometerHBH opened a new pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483


   ## Several Points
   1. Rename vars/size_vars/buffers, using their name hints to identify them.
   2. A node is either in META or not. If ref A is put in META and ref B is contained in A, then B is in META too, so when we need to print ref B we print its META entry.
   3. For now, only the `node` of an `AttrStmt` may be put in META, unless that `node` is a `StringImm`/`Var`/`Buffer`/`IterVar`.
   4. Print the type of each var when we first encounter it in the tree.
   5. This PR doesn't yet combine relay's printer with tir's printer. Here are several choices for a combined API, for discussion:
   - add an `astext` member function to IRModule: we only do mixed printing when printing an IRModule, and keep `relay.astext` and `tir.astext` separate for other uses.
   - merge relay's printer and tir's printer into one printer and expose a single API, `ir.astext`.
   
   cc @tqchen @Hzfengsy 
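   
   Points 1 and 4 above can be illustrated with a small standalone sketch. It is not the PR's code: `NameTable`, `Unique`, and `PrintVar` are hypothetical names, variables are identified by a plain integer id, and types are plain strings. It only shows the idea of deriving unique names from name hints and attaching the type on first use.
   
   ```cpp
   #include <cassert>
   #include <string>
   #include <unordered_map>
   
   // Hypothetical stand-in for the printer's naming state (not TVM's classes).
   class NameTable {
    public:
     // Returns `hint` if it is unused, otherwise the first free `hint_N`.
     std::string Unique(const std::string& hint) {
       std::string name = hint;
       auto it = counters_.find(hint);
       if (it != counters_.end()) {
         do {
           name = hint + "_" + std::to_string(++it->second);
         } while (counters_.count(name) != 0);
       }
       counters_[name] = 0;  // reserve the chosen name
       return name;
     }
   
     // Prints "name: type" the first time `id` is seen and just "name" afterwards.
     std::string PrintVar(int id, const std::string& hint, const std::string& type) {
       auto it = memo_.find(id);
       if (it != memo_.end()) return it->second;
       std::string name = Unique(hint);
       memo_[id] = name;
       return name + ": " + type;
     }
   
    private:
     std::unordered_map<std::string, int> counters_;  // name -> last suffix tried
     std::unordered_map<int, std::string> memo_;      // var id -> printed name
   };
   
   // Two distinct variables share the hint "A"; only first uses carry a type.
   inline void NameTableExample() {
     NameTable table;
     assert(table.PrintVar(0, "A", "handle") == "A: handle");
     assert(table.PrintVar(1, "A", "handle") == "A_1: handle");
     assert(table.PrintVar(0, "A", "handle") == "A");
   }
   ```
   
   This mirrors names such as `stride` and `stride_1` in the examples below, where the type annotation appears only at the first occurrence.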
   
   ## Several Examples
   ### Simple
   ```c++
   primfn(A0_1: handle, A1_1: handle, C_1: handle) -> ()
     attr = {"tir.noalias": bool(1), "global_symbol": "main"}
     buffers = {A1: Buffer(A1_2: handle, float32, [m: int32, n: int32], [stride: int32, stride_1: int32], type="auto"),
              C: Buffer(C_2: handle, float32, [m, n], [stride_2: int32, stride_3: int32], type="auto"),
              A0: Buffer(A0_2: handle, float32, [m, n], [stride_4: int32, stride_5: int32], type="auto")}
     buffer_map = {C_1: C, A0_1: A0, A1_1: A1} {
     attr [B.v0: handle] "storage_scope" = "global";
     allocate(B.v0, float32, [n])  {
       attr [B.v1: handle] "storage_scope" = "global";
       allocate(B.v1, float32, [n])  {
         for (i: int32, 0, m) {
           for (j: int32, 0, n) {
             B.v0[j] = (load(float32, A0_2[((i*stride_4) + (j*stride_5))]) + float32(2))
             B.v1[j] = (load(float32, A0_2[((i*stride_4) + (j*stride_5))])*float32(3))
           }
           for (j_1: int32, 0, n) {
             C_2[((i*stride_2) + (j_1*stride_3))] = (load(float32, A1_2[((i*stride) + (j_1*stride_1))]) + load(float32, B.v0[j_1]))
           }
         }
       }
     }
   }
   ```
   ### Conv on GPU
   ```c++
   primfn(A_1: handle, W_1: handle, B_1: handle) -> ()
     attr = {"tir.noalias": bool(1), "global_symbol": "cuda"}
     buffers = {W: Buffer(W_2: handle, float32, [3, 3, 256, 512], []),
              B: Buffer(B_2: handle, float32, [14, 14, 512, 256], []),
              A: Buffer(A_2: handle, float32, [14, 14, 256, 256], [])}
     buffer_map = {B_1: B, A_1: A, W_1: W} {
     attr [IterVar(blockIdx.z: int32, [(nullptr)], "ThreadIndex", "blockIdx.z")] "thread_extent" = 196;
     attr [B.local: handle] "storage_scope" = "local";
     allocate(B.local, float32, [64])  {
       attr [Apad.shared: handle] "storage_scope" = "shared";
       allocate(Apad.shared, float32, [512])  {
         attr [W.shared: handle] "storage_scope" = "shared";
         allocate(W.shared, float32, [512])  {
           attr [Apad.shared.local: handle] "storage_scope" = "local";
           allocate(Apad.shared.local, float32, [8])  {
             attr [W.shared.local: handle] "storage_scope" = "local";
             allocate(W.shared.local, float32, [8])  {
               attr [IterVar(blockIdx.y: int32, [(nullptr)], "ThreadIndex", "blockIdx.y")] "thread_extent" = 8;
               attr [IterVar(blockIdx.x: int32, [(nullptr)], "ThreadIndex", "blockIdx.x")] "thread_extent" = 4;
               attr [IterVar(threadIdx.y: int32, [0:8], "ThreadIndex", "threadIdx.y")] "thread_extent" = 8;
               attr [IterVar(threadIdx.x: int32, [0:8], "ThreadIndex", "threadIdx.x")] "thread_extent" = 8 {
                 for (ff.c.init: int32, 0, 4) {
                   for (nn.c.init: int32, 0, 4) {
                     B.local[((ff.c.init*4) + nn.c.init)] = float32(0)
                     B.local[(((ff.c.init*4) + nn.c.init) + 32)] = float32(0)
                     B.local[(((ff.c.init*4) + nn.c.init) + 16)] = float32(0)
                     B.local[(((ff.c.init*4) + nn.c.init) + 48)] = float32(0)
                   }
                 }
                 for (rc.outer: int32, 0, 32) {
                   for (ry: int32, 0, 3) {
                     for (rx: int32, 0, 3) {
                       for (ax3.inner.outer: int32, 0, 2) {
                         Apad.shared[ramp((((threadIdx.y*64) + (threadIdx.x*8)) + (ax3.inner.outer*4)), 1, 4)] = call("tvm_if_then_else", [((((1 <= (floordiv(blockIdx.z, 14) + ry)) and ((floordiv(blockIdx.z, 14) + ry) < 15)) and (1 <= (rx + floormod(blockIdx.z, 14)))) and ((rx + floormod(blockIdx.z, 14)) < 15)), load(float32x4, A_2[ramp((((((((((ry*917504) + (blockIdx.z*65536)) + (rx*65536)) + (rc.outer*2048)) + (threadIdx.y*256)) + (blockIdx.x*64)) + (threadIdx.x*8)) + (ax3.inner.outer*4)) - 983040), 1, 4)]), broadcast(float32(0), 4)], float32x4, "pure_intrin", 0)
                       }
                       for (ax3.inner.outer_1: int32, 0, 2) {
                         W.shared[ramp((((threadIdx.y*64) + (threadIdx.x*8)) + (ax3.inner.outer_1*4)), 1, 4)] = load(float32x4, W_2[ramp((((((((ry*393216) + (rx*131072)) + (rc.outer*4096)) + (threadIdx.y*512)) + (blockIdx.y*64)) + (threadIdx.x*8)) + (ax3.inner.outer_1*4)), 1, 4)])
                       }
                       for (rc.inner: int32, 0, 8) {
                         for (ax3: int32, 0, 4) {
                           Apad.shared.local[ax3] = load(float32, Apad.shared[(((rc.inner*64) + (threadIdx.x*4)) + ax3)])
                           Apad.shared.local[(ax3 + 4)] = load(float32, Apad.shared[((((rc.inner*64) + (threadIdx.x*4)) + ax3) + 32)])
                         }
                         for (ax3_1: int32, 0, 4) {
                           W.shared.local[ax3_1] = load(float32, W.shared[(((rc.inner*64) + (threadIdx.y*4)) + ax3_1)])
                           W.shared.local[(ax3_1 + 4)] = load(float32, W.shared[((((rc.inner*64) + (threadIdx.y*4)) + ax3_1) + 32)])
                         }
                         for (ff.c: int32, 0, 4) {
                           for (nn.c: int32, 0, 4) {
                             B.local[((ff.c*4) + nn.c)] = (load(float32, B.local[((ff.c*4) + nn.c)]) + (load(float32, Apad.shared.local[nn.c])*load(float32, W.shared.local[ff.c])))
                             B.local[(((ff.c*4) + nn.c) + 32)] = (load(float32, B.local[(((ff.c*4) + nn.c) + 32)]) + (load(float32, Apad.shared.local[nn.c])*load(float32, W.shared.local[(ff.c + 4)])))
                             B.local[(((ff.c*4) + nn.c) + 16)] = (load(float32, B.local[(((ff.c*4) + nn.c) + 16)]) + (load(float32, Apad.shared.local[(nn.c + 4)])*load(float32, W.shared.local[ff.c])))
                             B.local[(((ff.c*4) + nn.c) + 48)] = (load(float32, B.local[(((ff.c*4) + nn.c) + 48)]) + (load(float32, Apad.shared.local[(nn.c + 4)])*load(float32, W.shared.local[(ff.c + 4)])))
                           }
                         }
                       }
                     }
                   }
                 }
                 for (ff.inner.inner.inner: int32, 0, 4) {
                   for (nn.inner.inner.inner: int32, 0, 4) {
                     B_2[(((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024)) + (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner)] = load(float32, B.local[((ff.inner.inner.inner*4) + nn.inner.inner.inner)])
                     B_2[((((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024)) + (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner) + 8192)] = load(float32, B.local[(((ff.inner.inner.inner*4) + nn.inner.inner.inner) + 32)])
                     B_2[((((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024)) + (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner) + 32)] = load(float32, B.local[(((ff.inner.inner.inner*4) + nn.inner.inner.inner) + 16)])
                     B_2[((((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024)) + (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner) + 8224)] = load(float32, B.local[(((ff.inner.inner.inner*4) + nn.inner.inner.inner) + 48)])
                   }
                 }
               }
             }
           }
         }
       }
     }
   }
   ```
   ### TensorCore for Conv
   ```c++
   primfn(A_1: handle, W_1: handle, Conv_1: handle) -> ()
     attr = {"tir.noalias": bool(1), "global_symbol": "main"}
     buffers = {W: Buffer(W_2: handle, float16, [3, 3, 16, 32, 16, 16], []),
              A: Buffer(A_2: handle, float16, [16, 14, 14, 16, 16, 16], []),
              Conv: Buffer(Conv_2: handle, float32, [16, 14, 14, 32, 16, 16], [])}
     buffer_map = {A_1: A, Conv_1: Conv, W_1: W} {
     attr [IterVar(blockIdx.z: int32, [(nullptr)], "ThreadIndex", "blockIdx.z")] "thread_extent" = 196;
     attr [Conv.wmma.accumulator: handle] "storage_scope" = "wmma.accumulator";
     allocate(Conv.wmma.accumulator, float32, [2048])  {
       attr [Apad.shared: handle] "storage_scope" = "shared";
       allocate(Apad.shared, float16, [12288])  {
         attr [W.shared: handle] "storage_scope" = "shared";
         allocate(W.shared, float16, [12288])  {
           attr [Apad.shared.wmma.matrix_a: handle] "storage_scope" = "wmma.matrix_a";
           allocate(Apad.shared.wmma.matrix_a, float16, [512])  {
             attr [W.shared.wmma.matrix_b: handle] "storage_scope" = "wmma.matrix_b";
             allocate(W.shared.wmma.matrix_b, float16, [1024])  {
               attr [IterVar(blockIdx.x: int32, [(nullptr)], "ThreadIndex", "blockIdx.x")] "thread_extent" = 2;
               attr [IterVar(blockIdx.y: int32, [(nullptr)], "ThreadIndex", "blockIdx.y")] "thread_extent" = 4;
               attr [IterVar(threadIdx.y: int32, [(nullptr)], "ThreadIndex", "threadIdx.y")] "thread_extent" = 4;
               attr [IterVar(threadIdx.z: int32, [(nullptr)], "ThreadIndex", "threadIdx.z")] "thread_extent" = 2 {
                 for (n.c.init: int32, 0, 2) {
                   for (o.c.init: int32, 0, 4) {
                     eval(call("tvm_fill_fragment", [Conv.wmma.accumulator, 16, 16, 16, ((n.c.init*4) + o.c.init), float32(0)], handle, "intrin", 0))
                   }
                 }
                 for (ic.outer: int32, 0, 8) {
                   for (kh: int32, 0, 3) {
                     for (ax2: int32, 0, 3) {
                       for (ax3: int32, 0, 2) {
                         for (ax4.ax5.fused.outer: int32, 0, 8) {
                           attr [IterVar(threadIdx.x: int32, [(nullptr)], "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
                           Apad.shared[((((((threadIdx.y*3072) + (threadIdx.z*1536)) + (ax2*512)) + (ax3*256)) + (ax4.ax5.fused.outer*32)) + threadIdx.x)] = call("tvm_if_then_else", [((((1 <= (floordiv(blockIdx.z, 14) + kh)) and ((floordiv(blockIdx.z, 14) + kh) < 15)) and (1 <= (ax2 + floormod(blockIdx.z, 14)))) and ((ax2 + floormod(blockIdx.z, 14)) < 15)), load(float16, A_2[(((((((((((blockIdx.x*6422528) + (threadIdx.y*1605632)) + (threadIdx.z*802816)) + (kh*57344)) + (blockIdx.z*4096)) + (ax2*4096)) + (ic.outer*512)) + (ax3*256)) + (ax4.ax5.fused.outer*32)) + threadIdx.x) - 61440)]), float16(0)], float16, "pure_intrin", 0)
                         }
                       }
                     }
                     for (ax1: int32, 0, 3) {
                       for (ax2_1: int32, 0, 2) {
                         attr [IterVar(threadIdx.x, [(nullptr)], "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
                         W.shared[ramp((((((ax1*4096) + (ax2_1*2048)) + (threadIdx.y*512)) + (threadIdx.z*256)) + (threadIdx.x*8)), 1, 8)] = load(float16x8, W_2[ramp(((((((((kh*393216) + (ax1*131072)) + (ic.outer*16384)) + (ax2_1*8192)) + (blockIdx.y*2048)) + (threadIdx.y*512)) + (threadIdx.z*256)) + (threadIdx.x*8)), 1, 8)])
                       }
                     }
                     for (ic.inner: int32, 0, 2) {
                       for (kw: int32, 0, 3) {
                         for (ax0: int32, 0, 2) {
                           eval(call("tvm_load_matrix_sync", [Apad.shared.wmma.matrix_a, 16, 16, 16, ax0, call("tvm_access_ptr", [call("type_annotation", [], float16, "pure_intrin", 0), Apad.shared, ((((threadIdx.y*3072) + (ax0*1536)) + (kw*512)) + (ic.inner*256)), 256, 1], handle, "intrin", 0), 16, "row_major"], handle, "intrin", 0))
                         }
                         for (ax3_1: int32, 0, 4) {
                           eval(call("tvm_load_matrix_sync", [W.shared.wmma.matrix_b, 16, 16, 16, ax3_1, call("tvm_access_ptr", [call("type_annotation", [], float16, "pure_intrin", 0), W.shared, ((((kw*4096) + (ic.inner*2048)) + (threadIdx.z*1024)) + (ax3_1*256)), 256, 1], handle, "intrin", 0), 16, "row_major"], handle, "intrin", 0))
                         }
                         for (n.c: int32, 0, 2) {
                           for (o.c: int32, 0, 4) {
                             eval(call("tvm_mma_sync", [Conv.wmma.accumulator, ((n.c*4) + o.c), Apad.shared.wmma.matrix_a, n.c, W.shared.wmma.matrix_b, o.c, Conv.wmma.accumulator, ((n.c*4) + o.c)], handle, "intrin", 0))
                           }
                         }
                       }
                     }
                   }
                 }
                 for (n.inner: int32, 0, 2) {
                   for (o.inner: int32, 0, 4) {
                     eval(call("tvm_store_matrix_sync", [Conv.wmma.accumulator, 16, 16, 16, ((n.inner*4) + o.inner), call("tvm_access_ptr", [call("type_annotation", [], float32, "pure_intrin", 0), Conv_2, (((((((blockIdx.x*12845056) + (threadIdx.y*3211264)) + (n.inner*1605632)) + (blockIdx.z*8192)) + (blockIdx.y*2048)) + (threadIdx.z*1024)) + (o.inner*256)), 256, 2], handle, "intrin", 0), 16, "row_major"], handle, "intrin", 0))
                   }
                 }
               }
             }
           }
         }
       }
     }
   }
   ```
   ### GEMM on CPU
   ```c++
   primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
     attr = {"tir.noalias": bool(1), "global_symbol": "main"}
     buffers = {C: Buffer(C_2: handle, float32, [1024, 1024], []),
              B: Buffer(B_2: handle, float32, [1024, 1024], []),
              A: Buffer(A_2: handle, float32, [1024, 1024], [])}
     buffer_map = {C_1: C, A_1: A, B_1: B} {
     attr [packedB: handle] "storage_scope" = "global";
     allocate(packedB, float32x32, [32768])  {
       for (x: int32, 0, 32) "parallel" {
         for (y: int32, 0, 1024) {
           packedB[ramp(((x*32768) + (y*32)), 1, 32)] = load(float32x32, B_2[ramp(((y*1024) + (x*32)), 1, 32)])
         }
       }
       for (x.outer: int32, 0, 32) "parallel" {
         attr [C.global: handle] "storage_scope" = "global";
         allocate(C.global, float32, [1024])  {
           for (y.outer: int32, 0, 32) {
             for (x.c.init: int32, 0, 32) {
               C.global[ramp((x.c.init*32), 1, 32)] = broadcast(float32(0), 32)
             }
             for (k.outer: int32, 0, 256) {
               for (x.c: int32, 0, 32) {
                 C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32), 1, 32)]) + (broadcast(load(float32, A_2[(((x.outer*32768) + (x.c*1024)) + (k.outer*4))]), 32)*load(float32x32, packedB[ramp(((y.outer*32768) + (k.outer*128)), 1, 32)])))
                 C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32), 1, 32)]) + (broadcast(load(float32, A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 1)]), 32)*load(float32x32, packedB[ramp((((y.outer*32768) + (k.outer*128)) + 32), 1, 32)])))
                 C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32), 1, 32)]) + (broadcast(load(float32, A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 2)]), 32)*load(float32x32, packedB[ramp((((y.outer*32768) + (k.outer*128)) + 64), 1, 32)])))
                 C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32), 1, 32)]) + (broadcast(load(float32, A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 3)]), 32)*load(float32x32, packedB[ramp((((y.outer*32768) + (k.outer*128)) + 96), 1, 32)])))
               }
             }
             for (x.inner: int32, 0, 32) {
               for (y.inner: int32, 0, 32) {
                 C_2[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] = load(float32, C.global[((x.inner*32) + y.inner)])
               }
             }
           }
         }
       }
     }
   }
   ```


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] spectrometerHBH edited a comment on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH edited a comment on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-622203710


   > Some comments:
   > 
   > 1. There is more than one space before the left brace in the allocation line
   >    ```
   >    allocate(B.local, float32, [64])  {
   >    ```
   > 2. Can we use the same rule for the allocation stmt as for attr? The allocation stmt currently adds an extra level of indentation.
   > 3. ```
   >     attr [IterVar(blockIdx.z: int32, [(nullptr)], "ThreadIndex", "blockIdx.z")] "thread_extent" = 196;
   >    ```
   >    
   >    
   >    It is strange to print `nullptr` here, especially inside square brackets. Perhaps we can use `IterVar(blockIdx.z: int32, (nullptr), "ThreadIndex", "blockIdx.z")` or `IterVar(blockIdx.z: int32, , "ThreadIndex", "blockIdx.z")`
   > 4. With future parsing in mind, we must print the dtype of every constant. But we may use shorthand for common dtypes, e.g. `2f` for float32, `2h` for float16 (half), and a plain `2` for int32 (most integer constants in a schedule are int32). We should still keep the complete form for every type, e.g. `int8(2)` or `float64(2)` (or maybe `fp64(2)`); `float32(2)` is legal as well.
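   
   The shorthand proposed in point 4 of the quoted comments could look roughly like the sketch below. It is only an illustration of the proposal, not what the PR implements: `PrintConst` is a hypothetical helper, the dtype is a plain string rather than TVM's `DataType`, and the value is passed as already-formatted text.
   
   ```cpp
   #include <cassert>
   #include <string>
   
   // Hypothetical formatter for the proposed constant shorthand.
   std::string PrintConst(const std::string& dtype, const std::string& value) {
     if (dtype == "int32") return value;          // plain `2`
     if (dtype == "float32") return value + "f";  // `2f`
     if (dtype == "float16") return value + "h";  // `2h`
     return dtype + "(" + value + ")";            // complete form, e.g. `int8(2)`
   }
   
   // The spellings named in the comment above.
   inline void PrintConstExample() {
     assert(PrintConst("int32", "2") == "2");
     assert(PrintConst("float32", "2") == "2f");
     assert(PrintConst("float16", "2") == "2h");
     assert(PrintConst("int8", "2") == "int8(2)");
   }
   ```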
   
   1. fixed
   2. fixed. But this implicitly assumes that `Allocate` and `Attr` always have at least one child. Otherwise, given output such as
   ```c++
   attr...;
   attr...;
   attr...;
   ```
   we cannot determine whether it is `attr|attr|attr` or `attr(attr)|attr`.
   
   3. fixed
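   
   The ambiguity described in point 2 can be reproduced with a toy tree. This is a minimal sketch under stated assumptions: `Attr`, `Leaf`, `Wrap`, and `PrintFlat` are hypothetical, and an `Attr` is allowed to have no child precisely to show why the printer has to assume otherwise.
   
   ```cpp
   #include <cassert>
   #include <string>
   #include <vector>
   
   // Hypothetical miniature statement: an `attr` whose body may be empty.
   struct Attr {
     std::vector<Attr> body;
   };
   
   Attr Leaf() { return Attr{}; }
   
   Attr Wrap(const Attr& child) {
     Attr a;
     a.body.push_back(child);
     return a;
   }
   
   // Brace-free rule: print `attr...;`, then the body at the same indentation.
   std::string PrintFlat(const std::vector<Attr>& seq) {
     std::string out;
     for (const Attr& a : seq) {
       out += "attr...;\n";
       out += PrintFlat(a.body);
     }
     return out;
   }
   
   // `attr|attr|attr` and `attr(attr)|attr` produce identical text.
   inline void AmbiguityExample() {
     std::vector<Attr> flat = {Leaf(), Leaf(), Leaf()};
     std::vector<Attr> nested = {Wrap(Leaf()), Leaf()};
     assert(PrintFlat(flat) == PrintFlat(nested));
   }
   ```
   
   Since both trees print the same three lines, a parser could not recover the nesting unless every `attr` is known to own whatever follows it.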





[GitHub] [incubator-tvm] spectrometerHBH commented on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-626096064


   I found the problem. gcc 5.4 and gcc 7.4 evaluate the operands of this expression in different orders:
   ```c++
   doc << "(" << Print(op->a) << OpString << Print(op->b) << ")";
   ```
   
   gcc 5.4 executes `Print(op->b)` first.
   gcc 7.4 executes `Print(op->a)` first.
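   
   Before C++17, the order in which the operands of a chained `<<` expression are evaluated is unspecified, so both behaviors are conforming. One way to make the order deterministic is to evaluate each operand in its own statement first. The sketch below is an assumption about how that could look, not necessarily the change made in this PR; `Recorder` and `PrintBinary` are hypothetical stand-ins for the printer.
   
   ```cpp
   #include <cassert>
   #include <sstream>
   #include <string>
   #include <vector>
   
   // Hypothetical stand-in: Print() has a side effect (it records visit order),
   // much like allocating a fresh name the first time a var is encountered.
   struct Recorder {
     std::vector<std::string> visited;
     std::string Print(const std::string& operand) {
       visited.push_back(operand);
       return operand;
     }
   };
   
   // Each operand is printed in a separate statement, so the language fixes the
   // order; the stream expression then only combines already-computed strings.
   std::string PrintBinary(Recorder* r, const std::string& a, const std::string& op,
                           const std::string& b) {
     std::string lhs = r->Print(a);  // sequenced first
     std::string rhs = r->Print(b);  // sequenced second
     std::ostringstream os;
     os << "(" << lhs << op << rhs << ")";
     return os.str();
   }
   
   // The left operand is always visited before the right one.
   inline void PrintBinaryExample() {
     Recorder r;
     assert(PrintBinary(&r, "i", "*", "stride") == "(i*stride)");
     assert(r.visited[0] == "i" && r.visited[1] == "stride");
   }
   ```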





[GitHub] [incubator-tvm] tqchen commented on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-622071330


   - In Allocate, `float32x32` is not the actual data type. Perhaps we can just show the flattened size being allocated, since that is the semantics.
   - `eval` is not necessary, since it can be implied by a call.
   - We might need to update call later to something like `@intrin.func_name(args)`





[GitHub] [incubator-tvm] spectrometerHBH commented on a change in pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r421900677



##########
File path: python/tvm/tir/parser.py
##########
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import, redefined-builtin
+"""Namespace for Tensor-level IR"""
+
+from . import _ffi_api
+
+
+def astext(code):

Review comment:
       I see. But why is tir.Stmt not a subclass of Node now?







[GitHub] [incubator-tvm] junrushao1994 commented on a change in pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r418371548



##########
File path: src/printer/tir_text_printer.cc
##########
@@ -0,0 +1,735 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file printer/tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ *  \brief Meta node collector
+ *  If we decide to put some node into meta, then all the sub-nodes inside
+ *  it need to be put in meta as well, since when parsing we need to know
+ *  whether two refs are the same
+ */
+class MetaCollector : public StmtExprVisitor {
+ public:
+  explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
+
+  void Collect(const ObjectRef& n) {
+    // these nodes can be print directly(StringLiteral or use identifier to identify)
+    if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>()
+        || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+      return;
+    }
+    if (n->IsInstance<StmtNode>()) {
+      VisitStmt(Downcast<Stmt>(n));
+    } else if (n->IsInstance<PrimExprNode>()) {
+      VisitExpr(Downcast<PrimExpr>(n));
+    }
+  }
+
+  void VisitStmt(const Stmt& n) override {
+    meta_->GetMetaNode(n);
+    StmtVisitor::VisitStmt(n);
+  }
+
+  void VisitExpr(const PrimExpr& n) override {
+    meta_->GetMetaNode(n);
+    ExprVisitor::VisitExpr(n);
+  }
+
+ private:
+  TextMetaDataContext* meta_;
+};
+
+class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
+                       public ExprFunctor<Doc(const PrimExpr&)>,
+                       public TypeFunctor<Doc(const Type&)> {
+ public:
+  explicit TIRTextPrinter(bool show_meta) : show_meta_(show_meta), meta_collector_(&meta_) {}
+
+  /*! \brief Print the node */
+  Doc Print(const ObjectRef& node);
+
+ private:
+  /*! \brief whether show meta data */
+  bool show_meta_;
+  /*! \brief meta data context */
+  TextMetaDataContext meta_;
+  /*! \brief meta collector */
+  MetaCollector meta_collector_;
+  /*! \brief Map from Var to Doc */
+  std::unordered_map<Var, Doc, ObjectHash, ObjectEqual> memo_var_;
+  /*! \brief Map from Buffer to Doc */
+  std::unordered_map<Buffer, Doc, ObjectHash, ObjectEqual> memo_buf_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+
+  Doc VisitExpr_(const IntImmNode* op) override;
+  Doc VisitExpr_(const FloatImmNode* op) override;
+  Doc VisitExpr_(const StringImmNode* op) override;
+  Doc VisitExpr_(const CastNode* op) override;
+  Doc VisitExpr_(const VarNode* op) override;
+  Doc VisitExpr_(const AddNode* op) override;
+  Doc VisitExpr_(const SubNode* op) override;
+  Doc VisitExpr_(const MulNode* op) override;
+  Doc VisitExpr_(const DivNode* op) override;
+  Doc VisitExpr_(const ModNode* op) override;
+  Doc VisitExpr_(const FloorDivNode* op) override;
+  Doc VisitExpr_(const FloorModNode* op) override;
+  Doc VisitExpr_(const MinNode* op) override;
+  Doc VisitExpr_(const MaxNode* op) override;
+  Doc VisitExpr_(const EQNode* op) override;
+  Doc VisitExpr_(const NENode* op) override;
+  Doc VisitExpr_(const LTNode* op) override;
+  Doc VisitExpr_(const LENode* op) override;
+  Doc VisitExpr_(const GTNode* op) override;
+  Doc VisitExpr_(const GENode* op) override;
+  Doc VisitExpr_(const AndNode* op) override;
+  Doc VisitExpr_(const OrNode* op) override;
+  Doc VisitExpr_(const NotNode* op) override;
+  Doc VisitExpr_(const SelectNode* op) override;
+  Doc VisitExpr_(const BufferLoadNode* op) override;
+  Doc VisitExpr_(const LoadNode* op) override;
+  Doc VisitExpr_(const RampNode* op) override;
+  Doc VisitExpr_(const BroadcastNode* op) override;
+  Doc VisitExpr_(const LetNode* op) override;
+  Doc VisitExpr_(const CallNode* op) override;
+  Doc VisitExpr_(const ShuffleNode* op) override;
+  Doc VisitExpr_(const ReduceNode* op) override;
+  Doc VisitExprDefault_(const Object* op) override;
+
+  Doc VisitStmt_(const LetStmtNode* op) override;
+  Doc VisitStmt_(const AttrStmtNode* op) override;
+  Doc VisitStmt_(const AssertStmtNode* op) override;
+  Doc VisitStmt_(const StoreNode* op) override;
+  Doc VisitStmt_(const BufferStoreNode* op) override;
+  Doc VisitStmt_(const BufferRealizeNode* op) override;
+  Doc VisitStmt_(const AllocateNode* op) override;
+  Doc VisitStmt_(const FreeNode* op) override;
+  Doc VisitStmt_(const IfThenElseNode* op) override;
+  Doc VisitStmt_(const SeqStmtNode* op) override;
+  Doc VisitStmt_(const EvaluateNode* op) override;
+  Doc VisitStmt_(const ForNode* op) override;
+  Doc VisitStmt_(const PrefetchNode* op) override;
+  Doc VisitStmtDefault_(const Object* op) override;
+
+  Doc VisitType_(const PrimTypeNode* node) override;
+  Doc VisitType_(const PointerTypeNode* node) override;
+  Doc VisitType_(const TupleTypeNode* node) override;
+
+  Doc PrintIRModule(const IRModule& module);
+  Doc PrintPrimFunc(const PrimFunc& primFunc);
+  Doc PrintArray(const ArrayNode* op);
+  Doc PrintIterVar(const IterVarNode* op);
+  Doc PrintRange(const RangeNode* op);
+  Doc PrintBuffer(const BufferNode* op);
+  Doc PrintString(const StringObj* op) {
+    return Doc::StrLiteral(op->data);
+  }
+
+  /*!
+   * \brief special method to print out data type
+   * \param dtype The data type
+   */
+  static Doc PrintDType(DataType dtype) {
+    return Doc::Text(runtime::DLDataType2String(dtype));
+  }
+
+  /*!
+   * \brief special method to print out const scalar
+   * \param dtype The data type
+   * \param data The pointer to hold the data.
+   */
+  template <typename T>
+  static Doc PrintConstScalar(DataType dtype, const T* data) {
+    Doc doc;
+    std::ostringstream os;
+    os << data[0];
+    if (dtype == DataType::Int(32)) {
+      doc << Doc::Text(os.str());
+    } else {
+      doc << PrintDType(dtype) << "(" << Doc::Text(os.str()) << ")";
+    }
+    return doc;
+  }
+
+  Doc GetUniqueName(std::string prefix) {
+    // std::replace(prefix.begin(), prefix.end(), '.', '_');
+    std::string unique_prefix = prefix;
+    auto it = name_alloc_map_.find(prefix);
+    if (it != name_alloc_map_.end()) {
+      while (true) {
+        std::ostringstream os;
+        os << prefix << "_" << (++it->second);
+        std::string name = os.str();
+        if (name_alloc_map_.count(name) == 0) {
+          unique_prefix = name;
+          break;
+        }
+      }
+    }
+    name_alloc_map_[unique_prefix] = 0;
+    return Doc::Text(unique_prefix);
+  }
+
+  Doc AllocVar(const Var& var) {
+    const auto& it = memo_var_.find(var);
+    if (it != memo_var_.end()) {
+      return it->second;
+    }
+    std::string name = var->name_hint;
+    if (name.length() == 0 || !std::isalpha(name[0])) {
+      name = "v" + name;
+    }
+    Doc val = GetUniqueName(name);
+    memo_var_[var] = val;
+    return val << ": " << Print(GetType(var));
+  }
+
+  Doc AllocBuf(const Buffer& buffer) {
+    const auto& it = memo_buf_.find(buffer);
+    if (it != memo_buf_.end()) {
+      return it->second;
+    }
+    std::string name = buffer->name;
+    if (name.length() == 0 || !std::isalpha(name[0])) {
+      name = "buf_" + name;
+    }
+    Doc val = GetUniqueName(name);
+    memo_buf_[buffer] = val;
+    return val;
+  }
+
+  /*!
+   * \brief special method to render vectors of docs with a separator
+   * \param vec vector of docs
+   * \param sep separator
+   */
+  static Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
+    Doc seq;
+    if (vec.size() != 0) {
+      seq = vec[0];
+      for (size_t i = 1; i < vec.size(); i++) {
+        seq << sep << vec[i];
+      }
+    }
+    return seq;
+  }
+
+  /*!
+   * \brief dump meta info
+   * \return Doc with meta info
+   */
+  Doc DumpMeta() {
+    if (show_meta_) {
+      return Doc::Text("__tvm_meta__ = ")
+          << (meta_.empty() ? Doc::Text("None") : meta_.GetMetaSection());
+    } else {
+      return Doc::Text("");
+    }
+  }
+
+  Doc PrintBody(const Stmt& body, bool indent = true) {
+    Doc doc;
+    if (body->IsInstance<SeqStmtNode>()) return Print(body);
+    doc << " {" << Doc::Indent(2, Doc::NewLine() << Print(body)) << Doc::NewLine() << "}";
+    return doc;
+  }
+};
+
+Doc TIRTextPrinter::Print(const ObjectRef& node) {
+  if (!node.defined()) return Doc::Text("(nullptr)");
+  if (node.as<StmtNode>()) {

Review comment:
   ```suggestion
     if (node->IsInstance<StmtNode>()) {
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] spectrometerHBH commented on a change in pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r418400818



##########
File path: src/printer/tir_text_printer.cc
##########
@@ -0,0 +1,735 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file printer/tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ *  \brief Meta node collector
+ *  If we decide to put some node into meta, then all the sub-nodes inside
+ *  it need to be put in meta as well, since when parsing we need to know
+ *  whether two refs are the same
+ */
+class MetaCollector : public StmtExprVisitor {
+ public:
+  explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
+
+  void Collect(const ObjectRef& n) {
+    // these nodes can be print directly(StringLiteral or use identifier to identify)
+    if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>()
+        || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+      return;
+    }
+    if (n->IsInstance<StmtNode>()) {
+      VisitStmt(Downcast<Stmt>(n));
+    } else if (n->IsInstance<PrimExprNode>()) {
+      VisitExpr(Downcast<PrimExpr>(n));
+    }
+  }
+
+  void VisitStmt(const Stmt& n) override {
+    meta_->GetMetaNode(n);
+    StmtVisitor::VisitStmt(n);
+  }
+
+  void VisitExpr(const PrimExpr& n) override {
+    meta_->GetMetaNode(n);
+    ExprVisitor::VisitExpr(n);
+  }
+
+ private:
+  TextMetaDataContext* meta_;
+};
+
+class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
+                       public ExprFunctor<Doc(const PrimExpr&)>,
+                       public TypeFunctor<Doc(const Type&)> {
+ public:
+  explicit TIRTextPrinter(bool show_meta) : show_meta_(show_meta), meta_collector_(&meta_) {}
+
+  /*! \brief Print the node */
+  Doc Print(const ObjectRef& node);
+
+ private:
+  /*! \brief whether show meta data */
+  bool show_meta_;
+  /*! \brief meta data context */
+  TextMetaDataContext meta_;
+  /*! \brief meta collector */
+  MetaCollector meta_collector_;
+  /*! \brief Map from Var to Doc */
+  std::unordered_map<Var, Doc, ObjectHash, ObjectEqual> memo_var_;
+  /*! \brief Map from Buffer to Doc */
+  std::unordered_map<Buffer, Doc, ObjectHash, ObjectEqual> memo_buf_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+
+  Doc VisitExpr_(const IntImmNode* op) override;
+  Doc VisitExpr_(const FloatImmNode* op) override;
+  Doc VisitExpr_(const StringImmNode* op) override;
+  Doc VisitExpr_(const CastNode* op) override;
+  Doc VisitExpr_(const VarNode* op) override;
+  Doc VisitExpr_(const AddNode* op) override;
+  Doc VisitExpr_(const SubNode* op) override;
+  Doc VisitExpr_(const MulNode* op) override;
+  Doc VisitExpr_(const DivNode* op) override;
+  Doc VisitExpr_(const ModNode* op) override;
+  Doc VisitExpr_(const FloorDivNode* op) override;
+  Doc VisitExpr_(const FloorModNode* op) override;
+  Doc VisitExpr_(const MinNode* op) override;
+  Doc VisitExpr_(const MaxNode* op) override;
+  Doc VisitExpr_(const EQNode* op) override;
+  Doc VisitExpr_(const NENode* op) override;
+  Doc VisitExpr_(const LTNode* op) override;
+  Doc VisitExpr_(const LENode* op) override;
+  Doc VisitExpr_(const GTNode* op) override;
+  Doc VisitExpr_(const GENode* op) override;
+  Doc VisitExpr_(const AndNode* op) override;
+  Doc VisitExpr_(const OrNode* op) override;
+  Doc VisitExpr_(const NotNode* op) override;
+  Doc VisitExpr_(const SelectNode* op) override;
+  Doc VisitExpr_(const BufferLoadNode* op) override;
+  Doc VisitExpr_(const LoadNode* op) override;
+  Doc VisitExpr_(const RampNode* op) override;
+  Doc VisitExpr_(const BroadcastNode* op) override;
+  Doc VisitExpr_(const LetNode* op) override;
+  Doc VisitExpr_(const CallNode* op) override;
+  Doc VisitExpr_(const ShuffleNode* op) override;
+  Doc VisitExpr_(const ReduceNode* op) override;
+  Doc VisitExprDefault_(const Object* op) override;
+
+  Doc VisitStmt_(const LetStmtNode* op) override;
+  Doc VisitStmt_(const AttrStmtNode* op) override;
+  Doc VisitStmt_(const AssertStmtNode* op) override;
+  Doc VisitStmt_(const StoreNode* op) override;
+  Doc VisitStmt_(const BufferStoreNode* op) override;
+  Doc VisitStmt_(const BufferRealizeNode* op) override;
+  Doc VisitStmt_(const AllocateNode* op) override;
+  Doc VisitStmt_(const FreeNode* op) override;
+  Doc VisitStmt_(const IfThenElseNode* op) override;
+  Doc VisitStmt_(const SeqStmtNode* op) override;
+  Doc VisitStmt_(const EvaluateNode* op) override;
+  Doc VisitStmt_(const ForNode* op) override;
+  Doc VisitStmt_(const PrefetchNode* op) override;
+  Doc VisitStmtDefault_(const Object* op) override;
+
+  Doc VisitType_(const PrimTypeNode* node) override;
+  Doc VisitType_(const PointerTypeNode* node) override;
+  Doc VisitType_(const TupleTypeNode* node) override;
+
+  Doc PrintIRModule(const IRModule& module);
+  Doc PrintPrimFunc(const PrimFunc& primFunc);
+  Doc PrintArray(const ArrayNode* op);
+  Doc PrintIterVar(const IterVarNode* op);
+  Doc PrintRange(const RangeNode* op);
+  Doc PrintBuffer(const BufferNode* op);
+  Doc PrintString(const StringObj* op) {
+    return Doc::StrLiteral(op->data);
+  }
+
+  /*!
+   * \brief special method to print out data type
+   * \param dtype The data type
+   */
+  static Doc PrintDType(DataType dtype) {
+    return Doc::Text(runtime::DLDataType2String(dtype));
+  }
+
+  /*!
+   * \brief special method to print out const scalar
+   * \param dtype The data type
+   * \param data The pointer to hold the data.
+   */
+  template <typename T>
+  static Doc PrintConstScalar(DataType dtype, const T* data) {
+    Doc doc;
+    std::ostringstream os;
+    os << data[0];
+    if (dtype == DataType::Int(32)) {
+      doc << Doc::Text(os.str());
+    } else {
+      doc << PrintDType(dtype) << "(" << Doc::Text(os.str()) << ")";
+    }
+    return doc;
+  }
+
+  Doc GetUniqueName(std::string prefix) {
+    // std::replace(prefix.begin(), prefix.end(), '.', '_');
+    std::string unique_prefix = prefix;
+    auto it = name_alloc_map_.find(prefix);
+    if (it != name_alloc_map_.end()) {
+      while (true) {
+        std::ostringstream os;
+        os << prefix << "_" << (++it->second);
+        std::string name = os.str();
+        if (name_alloc_map_.count(name) == 0) {
+          unique_prefix = name;
+          break;
+        }
+      }
+    }
+    name_alloc_map_[unique_prefix] = 0;
+    return Doc::Text(unique_prefix);
+  }
+
+  Doc AllocVar(const Var& var) {
+    const auto& it = memo_var_.find(var);
+    if (it != memo_var_.end()) {
+      return it->second;
+    }
+    std::string name = var->name_hint;
+    if (name.length() == 0 || !std::isalpha(name[0])) {
+      name = "v" + name;
+    }
+    Doc val = GetUniqueName(name);
+    memo_var_[var] = val;
+    return val << ": " << Print(GetType(var));
+  }
+
+  Doc AllocBuf(const Buffer& buffer) {
+    const auto& it = memo_buf_.find(buffer);
+    if (it != memo_buf_.end()) {
+      return it->second;
+    }
+    std::string name = buffer->name;
+    if (name.length() == 0 || !std::isalpha(name[0])) {
+      name = "buf_" + name;
+    }
+    Doc val = GetUniqueName(name);
+    memo_buf_[buffer] = val;
+    return val;
+  }
+
+  /*!
+   * \brief special method to render vectors of docs with a separator
+   * \param vec vector of docs
+   * \param sep separator
+   */
+  static Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
+    Doc seq;
+    if (vec.size() != 0) {
+      seq = vec[0];
+      for (size_t i = 1; i < vec.size(); i++) {
+        seq << sep << vec[i];
+      }
+    }
+    return seq;
+  }
+
+  /*!
+   * \brief dump meta info
+   * \return Doc with meta info
+   */
+  Doc DumpMeta() {
+    if (show_meta_) {
+      return Doc::Text("__tvm_meta__ = ")
+          << (meta_.empty() ? Doc::Text("None") : meta_.GetMetaSection());
+    } else {
+      return Doc::Text("");
+    }
+  }
+
+  Doc PrintBody(const Stmt& body, bool indent = true) {
+    Doc doc;
+    if (body->IsInstance<SeqStmtNode>()) return Print(body);
+    doc << " {" << Doc::Indent(2, Doc::NewLine() << Print(body)) << Doc::NewLine() << "}";
+    return doc;
+  }
+};
+
+Doc TIRTextPrinter::Print(const ObjectRef& node) {
+  if (!node.defined()) return Doc::Text("(nullptr)");
+  if (node.as<StmtNode>()) {
+    return VisitStmt(Downcast<Stmt>(node));
+  } else if (node.as<PrimExprNode>()) {
+    return VisitExpr(Downcast<PrimExpr>(node));
+  } else if (node.as<TypeNode>()) {
+    return VisitType(Downcast<Type>(node));
+  } else if (node.as<PrimFuncNode>()) {
+    return PrintPrimFunc(Downcast<PrimFunc>(node));
+  } else if (node.as<IRModuleNode>()) {
+    return PrintIRModule(Downcast<IRModule>(node));
+  } else if (node.as<ArrayNode>()) {
+    return PrintArray(node.as<ArrayNode>());
+  } else if (node.as<IterVarNode>()) {
+    return PrintIterVar(node.as<IterVarNode>());
+  } else if (node.as<RangeNode>()) {
+    return PrintRange(node.as<RangeNode>());
+  } else if (node.as<BufferNode>()) {
+    return PrintBuffer(node.as<BufferNode>());
+  } else if (node.as<StringObj>()) {
+    return PrintString(node.as<StringObj>());
+  } else {
+    return this->meta_.GetMetaNode(node);
+  }
+}
+
+Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
+  const auto* op = primFunc.operator->();
+  const auto& signature = op->func_type_annotation();
+  // collect Meta in DictAttr
+  for (const auto& it : primFunc->attrs->dict) {
+    meta_collector_.Collect(it.second);
+  }
+  // collect buffers in buffer_map
+  memo_var_.clear();
+  memo_buf_.clear();
+  for (const auto& it : op->buffer_map) {
+    memo_buf_[it.second] = AllocBuf(it.second);
+  }
+  // print PrimFunc
+  Doc doc;
+  doc << "primfn" << "(";
+  // print params and its type annotation
+  std::vector<Doc> params;
+  for (const auto& param : op->params) {
+    params.push_back(Print(param));
+  }
+  Doc sep;
+  doc << PrintSep(params, Doc::Indent(9, Doc::Text(", "))) << ")";
+  // print return type
+  doc << " -> " << Print(signature->ret_type);
+  // print attr
+  Doc attr_doc;
+  std::vector<Doc> attr_docs;
+  for (const auto& it : op->attrs->dict) {
+    attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
+  }
+  attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}";
+  doc << Doc::Indent(2, attr_doc);
+  // print all the buffers in the tree
+  Doc buffer_doc;
+  std::vector<Doc> buffer_docs;
+  for (const auto& it : memo_buf_) {
+    const auto& buf = it.first;
+    buffer_docs.push_back(Print(buf)
+                          << Doc::Text(": Buffer(") << Print(buf->data) << ", "
+                          << PrintDType(buf->dtype) << ", " << Print(buf->shape) << ", "
+                          << Print(buf->strides));
+    if (!is_zero(buf->elem_offset)) {
+      buffer_docs.back() << ", elem_offset=" << Print(buf->elem_offset);
+    }
+    if (buf->scope != "global") {
+      buffer_docs.back() << ", scope=" << Doc::StrLiteral(buf->scope);
+    }
+    if (buf->data_alignment != 128) {
+      buffer_docs.back() << ", align=" << buf->data_alignment;
+    }
+    if (buf->offset_factor != 1) {
+      buffer_docs.back() << ", offset_factor=" << buf->offset_factor;
+    }
+    if (buf->buffer_type != 1) {
+      buffer_docs.back() << ", type=" << Doc::StrLiteral("auto");
+    }
+    buffer_docs.back() << ")";
+  }
+  buffer_doc << Doc::NewLine() << "buffers = {";
+  buffer_doc << PrintSep(buffer_docs, Doc::Indent(9, Doc::Text(",") << Doc::NewLine()));
+  doc << Doc::Indent(2, buffer_doc) << "}";
+  // print buffer_map
+  std::vector<Doc> buffer_map_doc;
+  for (const auto& it : op->buffer_map) {
+    buffer_map_doc.push_back(Print(it.first) << ": " << Print(it.second));
+  }
+  doc << Doc::Indent(2, Doc::NewLine()
+      << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
+  doc << PrintBody(op->body);
+  return doc;
+}
+
+Doc TIRTextPrinter::PrintIRModule(const IRModule& module) {
+  const auto* op = module.operator->();
+  Doc doc;
+
+  Doc body;
+  body << Doc::NewLine();
+  std::vector<Doc> functions;
+  for (auto it = op->functions.begin(); it != op->functions.end(); ++it) {
+    if ((*it).second.as<PrimFuncNode>()) {
+      functions.push_back(Print((*it).second));
+    }
+  }
+  body << TIRTextPrinter::PrintSep(functions, Doc::NewLine() << Doc::NewLine());
+  body << Doc::NewLine() << DumpMeta();
+  doc << Doc::Indent(0, body);
+  return doc;
+}
+
+Doc TIRTextPrinter::PrintArray(const ArrayNode* op) {
+  Doc doc;
+  doc << '[';
+  for (size_t i = 0; i < op->data.size(); ++i) {
+    if (i != 0) {
+      doc << ", ";
+    }
+    doc << Print(op->data[i]);
+  }
+  doc << ']';
+  return doc;
+}
+
+Doc TIRTextPrinter::PrintIterVar(const IterVarNode* op) {
+  Doc doc;
+  doc << "IterVar(" << Print(op->var) << ", [" << Print(op->dom) << "], "
+      << Doc::StrLiteral(IterVarType2String(op->iter_type)) << ", "
+      << Doc::StrLiteral(op->thread_tag) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::PrintRange(const RangeNode* op) {
+  return Print(op->min) << ":" << Print(op->min + op->extent);
+}
+
+Doc TIRTextPrinter::PrintBuffer(const BufferNode* op) {
+  const Buffer& buffer = GetRef<Buffer>(op);
+  CHECK_GT(memo_buf_.count(buffer), 0);
+  return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : memo_buf_[buffer];
+}
+
+Doc TIRTextPrinter::VisitExprDefault_(const Object* op) {
+  return this->meta_.GetMetaNode(GetRef<ObjectRef>(op));
+}
+
+Doc TIRTextPrinter::VisitStmtDefault_(const Object* op) {
+  return this->meta_.GetMetaNode(GetRef<ObjectRef>(op));
+}
+
+Doc TIRTextPrinter::VisitExpr_(const IntImmNode* op) {
+  return PrintConstScalar<int64_t>(op->dtype, &(op->value));
+}
+
+Doc TIRTextPrinter::VisitExpr_(const FloatImmNode* op) {
+  return PrintConstScalar<double>(op->dtype, &(op->value));
+}
+
+Doc TIRTextPrinter::VisitExpr_(const StringImmNode* op) { return Doc::StrLiteral(op->value); }
+
+Doc TIRTextPrinter::VisitExpr_(const CastNode* op) {
+  Doc doc;
+  doc << "cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const VarNode* op) {
+  const Var& var = GetRef<Var>(op);
+  return meta_.InMeta(var) ? meta_.GetMetaNode(var) : AllocVar(GetRef<Var>(op));
+}
+
+#define TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OpName, OpString)     \
+  Doc TIRTextPrinter::VisitExpr_(const OpName* op) {               \
+    Doc doc;                                                       \
+    doc << '(' << Print(op->a) << OpString << Print(op->b) << ")"; \
+    return doc;                                                    \
+  }
+
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AddNode, " + ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(SubNode, " - ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(MulNode, "*")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(DivNode, " / ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(ModNode, " % ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(EQNode, " == ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(NENode, " != ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LTNode, " < ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LENode, " <= ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GTNode, " > ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GENode, " >= ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AndNode, " and ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OrNode, " or ")
+
+Doc TIRTextPrinter::VisitExpr_(const FloorDivNode* op) {
+  Doc doc;
+  doc << "floordiv(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const FloorModNode* op) {
+  Doc doc;
+  doc << "floormod(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const MinNode* op) {
+  Doc doc;
+  doc << "min(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const MaxNode* op) {
+  Doc doc;
+  doc << "max(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const NotNode* op) {
+  Doc doc;
+  doc << "!" << Print(op->a);
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const SelectNode* op) {
+  Doc doc;
+  doc << "select(" << Print(op->condition) << ", " << Print(op->true_value) << ", "
+      << Print(op->false_value);
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const BufferLoadNode* op) {
+  Doc doc;
+  doc << Print(op->buffer) << Print(op->indices);
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const LoadNode* op) {
+  Doc doc;
+  doc << "load(" << PrintDType(op->dtype) << ", "
+      << Print(op->buffer_var) << "[" << Print(op->index) << "])";
+  if (!is_one(op->predicate)) {
+    doc << "if " << Print(op->predicate);
+  }
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const RampNode* op) {
+  Doc doc;
+  doc << "ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " << op->lanes << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const BroadcastNode* op) {
+  Doc doc;
+  doc << "broadcast(" << Print(op->value) << ", " << op->lanes << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const LetNode* op) {
+  Doc doc;
+  doc << "let " << Print(op->var) << " = " << Print(op->value) << " in " << Print(op->body);
+  return doc;
+}
+
+inline const char* CallType2String(CallNode::CallType t) {
+  switch (t) {
+    case CallNode::Extern:return "extern";
+    case CallNode::ExternCPlusPlus:return "extern_cpp";
+    case CallNode::PureExtern:return "pure_extern";
+    case CallNode::Halide:return "halide";
+    case CallNode::Intrinsic:return "intrin";
+    case CallNode::PureIntrinsic:return "pure_intrin";
+  }
+  return "Unknown";

Review comment:
       Fixed. The trailing `return` is kept, though; otherwise it triggers a compilation warning.







[GitHub] [incubator-tvm] xqdan commented on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
xqdan commented on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-625717683


   The binds in some pass functions are not clean for a round-trip IR dump. How do we deal with that?





[GitHub] [incubator-tvm] spectrometerHBH edited a comment on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH edited a comment on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-626096064


   I found the problem. It is due to gcc 5.4 and gcc 7.4 behaving differently on this expression:
   ```c++
   doc << "(" << Print(op->a) << OpString << Print(op->b) << ")";
   ```
   
   gcc 5.4 executes `Print(op->b)` first.
   gcc 7.4 executes `Print(op->a)` first.
   
   I have no idea why gcc 5.4 does so, since `<<` should be left-associative.
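
   A note on the behavior above: associativity only fixes how the expression is grouped, not the order in which the operands are evaluated. Before C++17 the evaluation order of the operands in a chained `<<` expression was unspecified, so either compiler's choice was permitted. Below is a minimal sketch of one way to make the order explicit. It uses a hypothetical `ToyPrinter` (not TVM's `TIRTextPrinter`) whose `Print` has a side effect, much as name allocation does in the real printer:

   ```c++
   #include <sstream>
   #include <string>
   #include <vector>

   // Hypothetical stand-in for the printer; Print() records the visit order
   // as its side effect.
   struct ToyPrinter {
     std::vector<std::string> visit_order;

     std::string Print(const std::string& operand) {
       visit_order.push_back(operand);
       return operand;
     }

     // Bind each operand to a named local first. Separate statements are
     // sequenced one after another, so lhs is always printed before rhs.
     std::string PrintBinOp(const std::string& a, const std::string& op,
                            const std::string& b) {
       std::string lhs = Print(a);
       std::string rhs = Print(b);
       std::ostringstream os;
       os << "(" << lhs << op << rhs << ")";
       return os.str();
     }
   };
   ```

   With this shape, `PrintBinOp("a", " + ", "b")` visits `a` before `b` regardless of the compiler version or language standard.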





[GitHub] [incubator-tvm] spectrometerHBH edited a comment on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH edited a comment on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-622203710


   > Some comments:
   > 
   > 1. There is more than one space before the left brace in the allocation line
   >    ```
   >    allocate(B.local, float32, [64])  {
   >    ```
   > 2. Can we use the same rule for the allocation stmt as the one for attr? Allocation stmt now will bring extra indentation
   > 3. ```
   >     attr [IterVar(blockIdx.z: int32, [(nullptr)], "ThreadIndex", "blockIdx.z")] "thread_extent" = 196;
   >    ```
   >    
   >    
   >    It is strange to print `nullptr` here, especially in square brackets. Perhaps we can use `IterVar(blockIdx.z: int32, (nullptr), "ThreadIndex", "blockIdx.z")` or `IterVar(blockIdx.z: int32, , "ThreadIndex", "blockIdx.z")`.
   > 4. Considering future parsing use, we must print the dtype for every constant. But we may use shorthand for common dtypes, e.g. `2f` for float32, `2h` for float16 (half), and a bare `2` for int32 (most integer constants in a schedule are int32). Still, keep the complete form available for every type, e.g. `int8(2)` and `float64(2)` (or maybe `fp64(2)`); `float32(2)` is legal as well.
   
   1. fixed
   2. fixed. But here we implicitly assume that `Allocate` and `Attr` always have at least one child. Otherwise, in a scenario such as
   ```c++
   attr...;
   attr...;
   attr...;
   for...;
   ```
   we cannot determine whether it is `attr|attr|attr|for`, `attr(attr)|attr|for`, `attr(attr)|attr(for)`, `attr(attr(attr))|for`, or `attr(attr(attr(for)))`.
   
   3. fixed
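
   The ambiguity in point 2 can be reproduced with a toy statement tree. This is only a sketch: `Stmt`, `Leaf`, `Wrap`, `Seq`, and `PrintFlat` are hypothetical names, not TVM's classes. It prints an attr-like statement and its body at the same indentation without braces, and shows that a flat sequence and a fully nested chain produce identical text:

   ```c++
   #include <memory>
   #include <string>
   #include <utility>
   #include <vector>

   struct Stmt;
   using StmtPtr = std::shared_ptr<Stmt>;

   // Hypothetical toy IR node: either a sequence (seq non-empty) or a single
   // statement that may wrap one child in `body`.
   struct Stmt {
     std::string text;
     StmtPtr body;
     std::vector<StmtPtr> seq;
   };

   StmtPtr Leaf(std::string text) {
     auto s = std::make_shared<Stmt>();
     s->text = std::move(text);
     return s;
   }

   StmtPtr Wrap(std::string text, StmtPtr body) {
     auto s = Leaf(std::move(text));
     s->body = std::move(body);
     return s;
   }

   StmtPtr Seq(std::vector<StmtPtr> stmts) {
     auto s = std::make_shared<Stmt>();
     s->seq = std::move(stmts);
     return s;
   }

   // Print without braces or extra indentation for wrapped bodies.
   std::string PrintFlat(const StmtPtr& s) {
     std::string out;
     if (!s->seq.empty()) {
       for (const auto& child : s->seq) out += PrintFlat(child);
       return out;
     }
     out = s->text + "\n";
     if (s->body) out += PrintFlat(s->body);
     return out;
   }
   ```

   Here `Seq({attr, attr, attr, for})` and `Wrap(attr, Wrap(attr, Wrap(attr, for)))` print the same four lines, so a parser could not tell them apart unless every `Attr` and `Allocate` is known to own the statements that follow it.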





[GitHub] [incubator-tvm] junrushao1994 commented on a change in pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r418368333



##########
File path: src/printer/tir_text_printer.cc
##########
@@ -0,0 +1,735 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file printer/tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ *  \brief Meta node collector
+ *  If we decide to put some node into meta, then all the sub-nodes inside
+ *  it need to be put in meta as well, since when parsing we need to know
+ *  whether two refs are the same
+ */
+class MetaCollector : public StmtExprVisitor {
+ public:
+  explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
+
+  void Collect(const ObjectRef& n) {
+    // these nodes can be print directly(StringLiteral or use identifier to identify)
+    if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>()
+        || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+      return;
+    }
+    if (n->IsInstance<StmtNode>()) {
+      VisitStmt(Downcast<Stmt>(n));
+    } else if (n->IsInstance<PrimExprNode>()) {
+      VisitExpr(Downcast<PrimExpr>(n));
+    }
+  }
+
+  void VisitStmt(const Stmt& n) override {
+    meta_->GetMetaNode(n);
+    StmtVisitor::VisitStmt(n);
+  }
+
+  void VisitExpr(const PrimExpr& n) override {
+    meta_->GetMetaNode(n);
+    ExprVisitor::VisitExpr(n);
+  }
+
+ private:
+  TextMetaDataContext* meta_;
+};
+
+class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
+                       public ExprFunctor<Doc(const PrimExpr&)>,
+                       public TypeFunctor<Doc(const Type&)> {
+ public:
+  explicit TIRTextPrinter(bool show_meta) : show_meta_(show_meta), meta_collector_(&meta_) {}
+
+  /*! \brief Print the node */
+  Doc Print(const ObjectRef& node);
+
+ private:
+  /*! \brief whether show meta data */
+  bool show_meta_;
+  /*! \brief meta data context */
+  TextMetaDataContext meta_;
+  /*! \brief meta collector */
+  MetaCollector meta_collector_;
+  /*! \brief Map from Var to Doc */
+  std::unordered_map<Var, Doc, ObjectHash, ObjectEqual> memo_var_;
+  /*! \brief Map from Buffer to Doc */
+  std::unordered_map<Buffer, Doc, ObjectHash, ObjectEqual> memo_buf_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+
+  Doc VisitExpr_(const IntImmNode* op) override;
+  Doc VisitExpr_(const FloatImmNode* op) override;
+  Doc VisitExpr_(const StringImmNode* op) override;
+  Doc VisitExpr_(const CastNode* op) override;
+  Doc VisitExpr_(const VarNode* op) override;
+  Doc VisitExpr_(const AddNode* op) override;
+  Doc VisitExpr_(const SubNode* op) override;
+  Doc VisitExpr_(const MulNode* op) override;
+  Doc VisitExpr_(const DivNode* op) override;
+  Doc VisitExpr_(const ModNode* op) override;
+  Doc VisitExpr_(const FloorDivNode* op) override;
+  Doc VisitExpr_(const FloorModNode* op) override;
+  Doc VisitExpr_(const MinNode* op) override;
+  Doc VisitExpr_(const MaxNode* op) override;
+  Doc VisitExpr_(const EQNode* op) override;
+  Doc VisitExpr_(const NENode* op) override;
+  Doc VisitExpr_(const LTNode* op) override;
+  Doc VisitExpr_(const LENode* op) override;
+  Doc VisitExpr_(const GTNode* op) override;
+  Doc VisitExpr_(const GENode* op) override;
+  Doc VisitExpr_(const AndNode* op) override;
+  Doc VisitExpr_(const OrNode* op) override;
+  Doc VisitExpr_(const NotNode* op) override;
+  Doc VisitExpr_(const SelectNode* op) override;
+  Doc VisitExpr_(const BufferLoadNode* op) override;
+  Doc VisitExpr_(const LoadNode* op) override;
+  Doc VisitExpr_(const RampNode* op) override;
+  Doc VisitExpr_(const BroadcastNode* op) override;
+  Doc VisitExpr_(const LetNode* op) override;
+  Doc VisitExpr_(const CallNode* op) override;
+  Doc VisitExpr_(const ShuffleNode* op) override;
+  Doc VisitExpr_(const ReduceNode* op) override;
+  Doc VisitExprDefault_(const Object* op) override;
+
+  Doc VisitStmt_(const LetStmtNode* op) override;
+  Doc VisitStmt_(const AttrStmtNode* op) override;
+  Doc VisitStmt_(const AssertStmtNode* op) override;
+  Doc VisitStmt_(const StoreNode* op) override;
+  Doc VisitStmt_(const BufferStoreNode* op) override;
+  Doc VisitStmt_(const BufferRealizeNode* op) override;
+  Doc VisitStmt_(const AllocateNode* op) override;
+  Doc VisitStmt_(const FreeNode* op) override;
+  Doc VisitStmt_(const IfThenElseNode* op) override;
+  Doc VisitStmt_(const SeqStmtNode* op) override;
+  Doc VisitStmt_(const EvaluateNode* op) override;
+  Doc VisitStmt_(const ForNode* op) override;
+  Doc VisitStmt_(const PrefetchNode* op) override;
+  Doc VisitStmtDefault_(const Object* op) override;
+
+  Doc VisitType_(const PrimTypeNode* node) override;
+  Doc VisitType_(const PointerTypeNode* node) override;
+  Doc VisitType_(const TupleTypeNode* node) override;
+
+  Doc PrintIRModule(const IRModule& module);
+  Doc PrintPrimFunc(const PrimFunc& primFunc);
+  Doc PrintArray(const ArrayNode* op);
+  Doc PrintIterVar(const IterVarNode* op);
+  Doc PrintRange(const RangeNode* op);
+  Doc PrintBuffer(const BufferNode* op);
+  Doc PrintString(const StringObj* op) {
+    return Doc::StrLiteral(op->data);
+  }
+
+  /*!
+   * \brief special method to print out data type
+   * \param dtype The data type
+   */
+  static Doc PrintDType(DataType dtype) {
+    return Doc::Text(runtime::DLDataType2String(dtype));
+  }
+
+  /*!
+   * \brief special method to print out const scalar
+   * \param dtype The data type
+   * \param data The pointer to hold the data.
+   */
+  template <typename T>
+  static Doc PrintConstScalar(DataType dtype, const T* data) {
+    Doc doc;
+    std::ostringstream os;
+    os << data[0];
+    if (dtype == DataType::Int(32)) {
+      doc << Doc::Text(os.str());
+    } else {
+      doc << PrintDType(dtype) << "(" << Doc::Text(os.str()) << ")";
+    }
+    return doc;
+  }
+
+  Doc GetUniqueName(std::string prefix) {
+    // std::replace(prefix.begin(), prefix.end(), '.', '_');

Review comment:
       Not sure whether we should substitute "." with "_". Should we remove this line once we reach consensus? @tqchen 
   
   Furthermore, I think we should use a regex to restrict the characters allowed in a var's name.
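
The regex idea in the comment above could look roughly like the sketch below. This is only an illustration: the helper name `IsPrintableVarName` and the pattern itself are assumptions, not part of the PR. The pattern permits "." because names such as `B.v0` appear in the example output of this PR; a stricter rule would simply drop it from the character class.

```cpp
#include <regex>
#include <string>

// Hypothetical helper (not in the PR): accepts a name that starts with a
// letter or '_' and continues with letters, digits, '_' or '.'.
inline bool IsPrintableVarName(const std::string& name) {
  // Inside a bracket expression '.' is a literal dot, not a wildcard.
  static const std::regex kPattern("[A-Za-z_][A-Za-z0-9_.]*");
  // regex_match requires the whole string to match, so "" is rejected.
  return std::regex_match(name, kPattern);
}
```

A printer could call such a check before allocating a name and fall back to a sanitized or prefixed name when it fails.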




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] junrushao1994 commented on a change in pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r418368333



##########
File path: src/printer/tir_text_printer.cc
##########
@@ -0,0 +1,735 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file printer/tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ *  \brief Meta node collector
+ *  If we decide to put some node into meta, then all the sub-nodes inside
+ *  it need to be put in meta as well, since when parsing we need to know
+ *  whether two refs are the same
+ */
+class MetaCollector : public StmtExprVisitor {
+ public:
+  explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
+
+  void Collect(const ObjectRef& n) {
+    // these nodes can be print directly(StringLiteral or use identifier to identify)
+    if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>()
+        || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+      return;
+    }
+    if (n->IsInstance<StmtNode>()) {
+      VisitStmt(Downcast<Stmt>(n));
+    } else if (n->IsInstance<PrimExprNode>()) {
+      VisitExpr(Downcast<PrimExpr>(n));
+    }
+  }
+
+  void VisitStmt(const Stmt& n) override {
+    meta_->GetMetaNode(n);
+    StmtVisitor::VisitStmt(n);
+  }
+
+  void VisitExpr(const PrimExpr& n) override {
+    meta_->GetMetaNode(n);
+    ExprVisitor::VisitExpr(n);
+  }
+
+ private:
+  TextMetaDataContext* meta_;
+};
+
+class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
+                       public ExprFunctor<Doc(const PrimExpr&)>,
+                       public TypeFunctor<Doc(const Type&)> {
+ public:
+  explicit TIRTextPrinter(bool show_meta) : show_meta_(show_meta), meta_collector_(&meta_) {}
+
+  /*! \brief Print the node */
+  Doc Print(const ObjectRef& node);
+
+ private:
+  /*! \brief whether show meta data */
+  bool show_meta_;
+  /*! \brief meta data context */
+  TextMetaDataContext meta_;
+  /*! \brief meta collector */
+  MetaCollector meta_collector_;
+  /*! \brief Map from Var to Doc */
+  std::unordered_map<Var, Doc, ObjectHash, ObjectEqual> memo_var_;
+  /*! \brief Map from Buffer to Doc */
+  std::unordered_map<Buffer, Doc, ObjectHash, ObjectEqual> memo_buf_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+
+  Doc VisitExpr_(const IntImmNode* op) override;
+  Doc VisitExpr_(const FloatImmNode* op) override;
+  Doc VisitExpr_(const StringImmNode* op) override;
+  Doc VisitExpr_(const CastNode* op) override;
+  Doc VisitExpr_(const VarNode* op) override;
+  Doc VisitExpr_(const AddNode* op) override;
+  Doc VisitExpr_(const SubNode* op) override;
+  Doc VisitExpr_(const MulNode* op) override;
+  Doc VisitExpr_(const DivNode* op) override;
+  Doc VisitExpr_(const ModNode* op) override;
+  Doc VisitExpr_(const FloorDivNode* op) override;
+  Doc VisitExpr_(const FloorModNode* op) override;
+  Doc VisitExpr_(const MinNode* op) override;
+  Doc VisitExpr_(const MaxNode* op) override;
+  Doc VisitExpr_(const EQNode* op) override;
+  Doc VisitExpr_(const NENode* op) override;
+  Doc VisitExpr_(const LTNode* op) override;
+  Doc VisitExpr_(const LENode* op) override;
+  Doc VisitExpr_(const GTNode* op) override;
+  Doc VisitExpr_(const GENode* op) override;
+  Doc VisitExpr_(const AndNode* op) override;
+  Doc VisitExpr_(const OrNode* op) override;
+  Doc VisitExpr_(const NotNode* op) override;
+  Doc VisitExpr_(const SelectNode* op) override;
+  Doc VisitExpr_(const BufferLoadNode* op) override;
+  Doc VisitExpr_(const LoadNode* op) override;
+  Doc VisitExpr_(const RampNode* op) override;
+  Doc VisitExpr_(const BroadcastNode* op) override;
+  Doc VisitExpr_(const LetNode* op) override;
+  Doc VisitExpr_(const CallNode* op) override;
+  Doc VisitExpr_(const ShuffleNode* op) override;
+  Doc VisitExpr_(const ReduceNode* op) override;
+  Doc VisitExprDefault_(const Object* op) override;
+
+  Doc VisitStmt_(const LetStmtNode* op) override;
+  Doc VisitStmt_(const AttrStmtNode* op) override;
+  Doc VisitStmt_(const AssertStmtNode* op) override;
+  Doc VisitStmt_(const StoreNode* op) override;
+  Doc VisitStmt_(const BufferStoreNode* op) override;
+  Doc VisitStmt_(const BufferRealizeNode* op) override;
+  Doc VisitStmt_(const AllocateNode* op) override;
+  Doc VisitStmt_(const FreeNode* op) override;
+  Doc VisitStmt_(const IfThenElseNode* op) override;
+  Doc VisitStmt_(const SeqStmtNode* op) override;
+  Doc VisitStmt_(const EvaluateNode* op) override;
+  Doc VisitStmt_(const ForNode* op) override;
+  Doc VisitStmt_(const PrefetchNode* op) override;
+  Doc VisitStmtDefault_(const Object* op) override;
+
+  Doc VisitType_(const PrimTypeNode* node) override;
+  Doc VisitType_(const PointerTypeNode* node) override;
+  Doc VisitType_(const TupleTypeNode* node) override;
+
+  Doc PrintIRModule(const IRModule& module);
+  Doc PrintPrimFunc(const PrimFunc& primFunc);
+  Doc PrintArray(const ArrayNode* op);
+  Doc PrintIterVar(const IterVarNode* op);
+  Doc PrintRange(const RangeNode* op);
+  Doc PrintBuffer(const BufferNode* op);
+  Doc PrintString(const StringObj* op) {
+    return Doc::StrLiteral(op->data);
+  }
+
+  /*!
+   * \brief special method to print out data type
+   * \param dtype The data type
+   */
+  static Doc PrintDType(DataType dtype) {
+    return Doc::Text(runtime::DLDataType2String(dtype));
+  }
+
+  /*!
+   * \brief special method to print out const scalar
+   * \param dtype The data type
+   * \param data The pointer to hold the data.
+   */
+  template <typename T>
+  static Doc PrintConstScalar(DataType dtype, const T* data) {
+    Doc doc;
+    std::ostringstream os;
+    os << data[0];
+    if (dtype == DataType::Int(32)) {
+      doc << Doc::Text(os.str());
+    } else {
+      doc << PrintDType(dtype) << "(" << Doc::Text(os.str()) << ")";
+    }
+    return doc;
+  }
+
+  Doc GetUniqueName(std::string prefix) {
+    // std::replace(prefix.begin(), prefix.end(), '.', '_');

Review comment:
       Not sure whether we should substitute "." with "_". Should we remove this line once we reach consensus? @tqchen 
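
One consideration for the substitution question: replacing "." with "_" maps distinct name hints onto the same text, so the printer would then rely on its unique-name counter to tell them apart. A minimal sketch, assuming a hypothetical helper `ReplaceDots` that mirrors the commented-out `std::replace` line:

```cpp
#include <algorithm>
#include <string>

// Hypothetical helper (not in the PR): same effect as the commented-out
// std::replace call, applied to a copy of the name hint.
inline std::string ReplaceDots(std::string name) {
  std::replace(name.begin(), name.end(), '.', '_');
  return name;
}
```

With this helper, the hints `B.v0` and `B_v0` both become `B_v0`, so whichever is printed second would need a numeric suffix to stay distinguishable.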







[GitHub] [incubator-tvm] tqchen edited a comment on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
tqchen edited a comment on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-622071330


   - In `allocate`, `float32x32` is not the actual data type. Perhaps we can just show the flattened allocation size, since that is the semantics. 
   - `eval` is not necessary, since it can be implied by a call. 
   - We might need to update call later to something like `@intrin.func_name(args)`.
   - Perhaps we do not need to add a new nested block for allocate, think of multiple let in





[GitHub] [incubator-tvm] spectrometerHBH edited a comment on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH edited a comment on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-625889950









[GitHub] [incubator-tvm] spectrometerHBH commented on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-626098731


   Both `release` and `debug` builds under gcc 5.4 show the same behavior, so it is unlikely that O2 causes this problem.





[GitHub] [incubator-tvm] junrushao1994 commented on a change in pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r418844571



##########
File path: src/printer/tir_text_printer.cc
##########
@@ -0,0 +1,735 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file printer/tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ *  \brief Meta node collector
+ *  If we decide to put some node into meta, then all the sub-nodes inside
+ *  it need to be put in meta as well, since when parsing we need to know
+ *  whether two refs are the same
+ */
+class MetaCollector : public StmtExprVisitor {
+ public:
+  explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
+
+  void Collect(const ObjectRef& n) {
+    // these nodes can be print directly(StringLiteral or use identifier to identify)
+    if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>()
+        || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+      return;
+    }
+    if (n->IsInstance<StmtNode>()) {
+      VisitStmt(Downcast<Stmt>(n));
+    } else if (n->IsInstance<PrimExprNode>()) {
+      VisitExpr(Downcast<PrimExpr>(n));
+    }
+  }
+
+  void VisitStmt(const Stmt& n) override {
+    meta_->GetMetaNode(n);
+    StmtVisitor::VisitStmt(n);
+  }
+
+  void VisitExpr(const PrimExpr& n) override {
+    meta_->GetMetaNode(n);
+    ExprVisitor::VisitExpr(n);
+  }
+
+ private:
+  TextMetaDataContext* meta_;
+};
+
+class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
+                       public ExprFunctor<Doc(const PrimExpr&)>,
+                       public TypeFunctor<Doc(const Type&)> {
+ public:
+  explicit TIRTextPrinter(bool show_meta) : show_meta_(show_meta), meta_collector_(&meta_) {}
+
+  /*! \brief Print the node */
+  Doc Print(const ObjectRef& node);
+
+ private:
+  /*! \brief whether show meta data */
+  bool show_meta_;
+  /*! \brief meta data context */
+  TextMetaDataContext meta_;
+  /*! \brief meta collector */
+  MetaCollector meta_collector_;
+  /*! \brief Map from Var to Doc */
+  std::unordered_map<Var, Doc, ObjectHash, ObjectEqual> memo_var_;
+  /*! \brief Map from Buffer to Doc */
+  std::unordered_map<Buffer, Doc, ObjectHash, ObjectEqual> memo_buf_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+
+  Doc VisitExpr_(const IntImmNode* op) override;
+  Doc VisitExpr_(const FloatImmNode* op) override;
+  Doc VisitExpr_(const StringImmNode* op) override;
+  Doc VisitExpr_(const CastNode* op) override;
+  Doc VisitExpr_(const VarNode* op) override;
+  Doc VisitExpr_(const AddNode* op) override;
+  Doc VisitExpr_(const SubNode* op) override;
+  Doc VisitExpr_(const MulNode* op) override;
+  Doc VisitExpr_(const DivNode* op) override;
+  Doc VisitExpr_(const ModNode* op) override;
+  Doc VisitExpr_(const FloorDivNode* op) override;
+  Doc VisitExpr_(const FloorModNode* op) override;
+  Doc VisitExpr_(const MinNode* op) override;
+  Doc VisitExpr_(const MaxNode* op) override;
+  Doc VisitExpr_(const EQNode* op) override;
+  Doc VisitExpr_(const NENode* op) override;
+  Doc VisitExpr_(const LTNode* op) override;
+  Doc VisitExpr_(const LENode* op) override;
+  Doc VisitExpr_(const GTNode* op) override;
+  Doc VisitExpr_(const GENode* op) override;
+  Doc VisitExpr_(const AndNode* op) override;
+  Doc VisitExpr_(const OrNode* op) override;
+  Doc VisitExpr_(const NotNode* op) override;
+  Doc VisitExpr_(const SelectNode* op) override;
+  Doc VisitExpr_(const BufferLoadNode* op) override;
+  Doc VisitExpr_(const LoadNode* op) override;
+  Doc VisitExpr_(const RampNode* op) override;
+  Doc VisitExpr_(const BroadcastNode* op) override;
+  Doc VisitExpr_(const LetNode* op) override;
+  Doc VisitExpr_(const CallNode* op) override;
+  Doc VisitExpr_(const ShuffleNode* op) override;
+  Doc VisitExpr_(const ReduceNode* op) override;
+  Doc VisitExprDefault_(const Object* op) override;
+
+  Doc VisitStmt_(const LetStmtNode* op) override;
+  Doc VisitStmt_(const AttrStmtNode* op) override;
+  Doc VisitStmt_(const AssertStmtNode* op) override;
+  Doc VisitStmt_(const StoreNode* op) override;
+  Doc VisitStmt_(const BufferStoreNode* op) override;
+  Doc VisitStmt_(const BufferRealizeNode* op) override;
+  Doc VisitStmt_(const AllocateNode* op) override;
+  Doc VisitStmt_(const FreeNode* op) override;
+  Doc VisitStmt_(const IfThenElseNode* op) override;
+  Doc VisitStmt_(const SeqStmtNode* op) override;
+  Doc VisitStmt_(const EvaluateNode* op) override;
+  Doc VisitStmt_(const ForNode* op) override;
+  Doc VisitStmt_(const PrefetchNode* op) override;
+  Doc VisitStmtDefault_(const Object* op) override;
+
+  Doc VisitType_(const PrimTypeNode* node) override;
+  Doc VisitType_(const PointerTypeNode* node) override;
+  Doc VisitType_(const TupleTypeNode* node) override;
+
+  Doc PrintIRModule(const IRModule& module);
+  Doc PrintPrimFunc(const PrimFunc& primFunc);
+  Doc PrintArray(const ArrayNode* op);
+  Doc PrintIterVar(const IterVarNode* op);
+  Doc PrintRange(const RangeNode* op);
+  Doc PrintBuffer(const BufferNode* op);
+  Doc PrintString(const StringObj* op) {
+    return Doc::StrLiteral(op->data);
+  }
+
+  /*!
+   * \brief special method to print out data type
+   * \param dtype The data type
+   */
+  static Doc PrintDType(DataType dtype) {
+    return Doc::Text(runtime::DLDataType2String(dtype));

Review comment:
       Something like void. I don't think that happens very often, though.







[GitHub] [incubator-tvm] spectrometerHBH commented on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-626165912


   Take another look please @tqchen 





[GitHub] [incubator-tvm] spectrometerHBH commented on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-624997033


   Relay's type printing requires attr and relayExpr printing, so it is not straightforward to make type printing independent. Meanwhile, relay's attr printing now overlaps with PrimExpr printing in tir. We'd better combine them in later PRs.





[GitHub] [incubator-tvm] tqchen commented on a change in pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r421904071



##########
File path: python/tvm/tir/parser.py
##########
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import, redefined-builtin
+"""Namespace for Tensor-level IR"""
+
+from . import _ffi_api
+
+
+def astext(code):

Review comment:
       There is no special Node class at the moment. For now, let us just add `astext` to stmt as well.







[GitHub] [incubator-tvm] spectrometerHBH commented on a change in pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r418378401



##########
File path: src/printer/tir_text_printer.cc
##########
@@ -0,0 +1,735 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file printer/tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ *  \brief Meta node collector
+ *  If we decide to put some node into meta, then all the sub-nodes inside
+ *  it need to be put in meta as well, since when parsing we need to know
+ *  whether two refs are the same
+ */
+class MetaCollector : public StmtExprVisitor {
+ public:
+  explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
+
+  void Collect(const ObjectRef& n) {
+    // these nodes can be print directly(StringLiteral or use identifier to identify)
+    if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>()
+        || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+      return;
+    }
+    if (n->IsInstance<StmtNode>()) {
+      VisitStmt(Downcast<Stmt>(n));
+    } else if (n->IsInstance<PrimExprNode>()) {
+      VisitExpr(Downcast<PrimExpr>(n));
+    }
+  }
+
+  void VisitStmt(const Stmt& n) override {
+    meta_->GetMetaNode(n);
+    StmtVisitor::VisitStmt(n);
+  }
+
+  void VisitExpr(const PrimExpr& n) override {
+    meta_->GetMetaNode(n);
+    ExprVisitor::VisitExpr(n);
+  }
+
+ private:
+  TextMetaDataContext* meta_;
+};
+
+class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
+                       public ExprFunctor<Doc(const PrimExpr&)>,
+                       public TypeFunctor<Doc(const Type&)> {
+ public:
+  explicit TIRTextPrinter(bool show_meta) : show_meta_(show_meta), meta_collector_(&meta_) {}
+
+  /*! \brief Print the node */
+  Doc Print(const ObjectRef& node);
+
+ private:
+  /*! \brief whether show meta data */
+  bool show_meta_;
+  /*! \brief meta data context */
+  TextMetaDataContext meta_;
+  /*! \brief meta collector */
+  MetaCollector meta_collector_;
+  /*! \brief Map from Var to Doc */
+  std::unordered_map<Var, Doc, ObjectHash, ObjectEqual> memo_var_;
+  /*! \brief Map from Buffer to Doc */
+  std::unordered_map<Buffer, Doc, ObjectHash, ObjectEqual> memo_buf_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+
+  Doc VisitExpr_(const IntImmNode* op) override;
+  Doc VisitExpr_(const FloatImmNode* op) override;
+  Doc VisitExpr_(const StringImmNode* op) override;
+  Doc VisitExpr_(const CastNode* op) override;
+  Doc VisitExpr_(const VarNode* op) override;
+  Doc VisitExpr_(const AddNode* op) override;
+  Doc VisitExpr_(const SubNode* op) override;
+  Doc VisitExpr_(const MulNode* op) override;
+  Doc VisitExpr_(const DivNode* op) override;
+  Doc VisitExpr_(const ModNode* op) override;
+  Doc VisitExpr_(const FloorDivNode* op) override;
+  Doc VisitExpr_(const FloorModNode* op) override;
+  Doc VisitExpr_(const MinNode* op) override;
+  Doc VisitExpr_(const MaxNode* op) override;
+  Doc VisitExpr_(const EQNode* op) override;
+  Doc VisitExpr_(const NENode* op) override;
+  Doc VisitExpr_(const LTNode* op) override;
+  Doc VisitExpr_(const LENode* op) override;
+  Doc VisitExpr_(const GTNode* op) override;
+  Doc VisitExpr_(const GENode* op) override;
+  Doc VisitExpr_(const AndNode* op) override;
+  Doc VisitExpr_(const OrNode* op) override;
+  Doc VisitExpr_(const NotNode* op) override;
+  Doc VisitExpr_(const SelectNode* op) override;
+  Doc VisitExpr_(const BufferLoadNode* op) override;
+  Doc VisitExpr_(const LoadNode* op) override;
+  Doc VisitExpr_(const RampNode* op) override;
+  Doc VisitExpr_(const BroadcastNode* op) override;
+  Doc VisitExpr_(const LetNode* op) override;
+  Doc VisitExpr_(const CallNode* op) override;
+  Doc VisitExpr_(const ShuffleNode* op) override;
+  Doc VisitExpr_(const ReduceNode* op) override;
+  Doc VisitExprDefault_(const Object* op) override;
+
+  Doc VisitStmt_(const LetStmtNode* op) override;
+  Doc VisitStmt_(const AttrStmtNode* op) override;
+  Doc VisitStmt_(const AssertStmtNode* op) override;
+  Doc VisitStmt_(const StoreNode* op) override;
+  Doc VisitStmt_(const BufferStoreNode* op) override;
+  Doc VisitStmt_(const BufferRealizeNode* op) override;
+  Doc VisitStmt_(const AllocateNode* op) override;
+  Doc VisitStmt_(const FreeNode* op) override;
+  Doc VisitStmt_(const IfThenElseNode* op) override;
+  Doc VisitStmt_(const SeqStmtNode* op) override;
+  Doc VisitStmt_(const EvaluateNode* op) override;
+  Doc VisitStmt_(const ForNode* op) override;
+  Doc VisitStmt_(const PrefetchNode* op) override;
+  Doc VisitStmtDefault_(const Object* op) override;
+
+  Doc VisitType_(const PrimTypeNode* node) override;
+  Doc VisitType_(const PointerTypeNode* node) override;
+  Doc VisitType_(const TupleTypeNode* node) override;
+
+  Doc PrintIRModule(const IRModule& module);
+  Doc PrintPrimFunc(const PrimFunc& primFunc);
+  Doc PrintArray(const ArrayNode* op);
+  Doc PrintIterVar(const IterVarNode* op);
+  Doc PrintRange(const RangeNode* op);
+  Doc PrintBuffer(const BufferNode* op);
+  Doc PrintString(const StringObj* op) {
+    return Doc::StrLiteral(op->data);
+  }
+
+  /*!
+   * \brief special method to print out data type
+   * \param dtype The data type
+   */
+  static Doc PrintDType(DataType dtype) {
+    return Doc::Text(runtime::DLDataType2String(dtype));
+  }
+
+  /*!
+   * \brief special method to print out const scalar
+   * \param dtype The data type
+   * \param data The pointer to hold the data.
+   */
+  template <typename T>
+  static Doc PrintConstScalar(DataType dtype, const T* data) {
+    Doc doc;
+    std::ostringstream os;
+    os << data[0];
+    if (dtype == DataType::Int(32)) {
+      doc << Doc::Text(os.str());
+    } else {
+      doc << PrintDType(dtype) << "(" << Doc::Text(os.str()) << ")";
+    }
+    return doc;
+  }
+
+  Doc GetUniqueName(std::string prefix) {
+    // std::replace(prefix.begin(), prefix.end(), '.', '_');
+    std::string unique_prefix = prefix;
+    auto it = name_alloc_map_.find(prefix);
+    if (it != name_alloc_map_.end()) {
+      while (true) {
+        std::ostringstream os;
+        os << prefix << "_" << (++it->second);
+        std::string name = os.str();
+        if (name_alloc_map_.count(name) == 0) {
+          unique_prefix = name;
+          break;
+        }
+      }
+    }
+    name_alloc_map_[unique_prefix] = 0;
+    return Doc::Text(unique_prefix);
+  }
+
+  Doc AllocVar(const Var& var) {
+    const auto& it = memo_var_.find(var);
+    if (it != memo_var_.end()) {
+      return it->second;
+    }
+    std::string name = var->name_hint;
+    if (name.length() == 0 || !std::isalpha(name[0])) {
+      name = "v" + name;
+    }
+    Doc val = GetUniqueName(name);

Review comment:
       If `v` has not been used yet, then it will be `v`. Otherwise, it will be `v_1`.
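
The suffixing behavior of `GetUniqueName` in the hunk above can be tried in isolation with a simplified stand-in. This is a sketch under stated assumptions: the class name `NameAllocator` is hypothetical, and it returns `std::string` where the PR returns a `Doc`, but the counter logic follows the quoted code.

```cpp
#include <string>
#include <unordered_map>

// Simplified stand-in for the PR's GetUniqueName (hypothetical class name;
// returns std::string instead of Doc).
class NameAllocator {
 public:
  std::string GetUniqueName(const std::string& prefix) {
    std::string unique = prefix;
    auto it = counters_.find(prefix);
    if (it != counters_.end()) {
      // The prefix is already taken: try prefix_1, prefix_2, ... and stop at
      // the first candidate that has not been handed out before.
      do {
        unique = prefix + "_" + std::to_string(++it->second);
      } while (counters_.count(unique) != 0);
    }
    // Record the chosen name so later requests for it get a suffix too.
    counters_[unique] = 0;
    return unique;
  }

 private:
  // Maps each allocated name to the last suffix tried for it.
  std::unordered_map<std::string, int> counters_;
};
```

The first request for `v` yields `v`, and repeated requests yield `v_1`, `v_2`, and so on. Because every returned name is itself recorded, a later hint that happens to equal `v_1` is also suffixed rather than reused.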







[GitHub] [incubator-tvm] spectrometerHBH commented on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-625889950


   I cannot reproduce the errors seen in CI.





[GitHub] [incubator-tvm] spectrometerHBH commented on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-626110351


   clang 6.0 also works, but I will fix it under gcc 5.4.





[GitHub] [incubator-tvm] junrushao1994 commented on a change in pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r418364252



##########
File path: src/printer/tir_text_printer.cc
##########
@@ -0,0 +1,735 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file printer/tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ *  \brief Meta node collector
+ *  If we decide to put some node into meta, then all the sub-nodes inside
+ *  it need to be put in meta as well, since when parsing we need to know
+ *  whether two refs are the same
+ */
+class MetaCollector : public StmtExprVisitor {
+ public:
+  explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
+
+  void Collect(const ObjectRef& n) {
+    // these nodes can be print directly(StringLiteral or use identifier to identify)
+    if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>()
+        || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+      return;
+    }
+    if (n->IsInstance<StmtNode>()) {
+      VisitStmt(Downcast<Stmt>(n));
+    } else if (n->IsInstance<PrimExprNode>()) {
+      VisitExpr(Downcast<PrimExpr>(n));
+    }
+  }
+
+  void VisitStmt(const Stmt& n) override {
+    meta_->GetMetaNode(n);
+    StmtVisitor::VisitStmt(n);
+  }
+
+  void VisitExpr(const PrimExpr& n) override {
+    meta_->GetMetaNode(n);
+    ExprVisitor::VisitExpr(n);
+  }
+
+ private:
+  TextMetaDataContext* meta_;
+};
+
+class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
+                       public ExprFunctor<Doc(const PrimExpr&)>,
+                       public TypeFunctor<Doc(const Type&)> {
+ public:
+  explicit TIRTextPrinter(bool show_meta) : show_meta_(show_meta), meta_collector_(&meta_) {}
+
+  /*! \brief Print the node */
+  Doc Print(const ObjectRef& node);
+
+ private:
+  /*! \brief whether show meta data */
+  bool show_meta_;
+  /*! \brief meta data context */
+  TextMetaDataContext meta_;
+  /*! \brief meta collector */
+  MetaCollector meta_collector_;
+  /*! \brief Map from Var to Doc */
+  std::unordered_map<Var, Doc, ObjectHash, ObjectEqual> memo_var_;
+  /*! \brief Map from Buffer to Doc */
+  std::unordered_map<Buffer, Doc, ObjectHash, ObjectEqual> memo_buf_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+
+  Doc VisitExpr_(const IntImmNode* op) override;
+  Doc VisitExpr_(const FloatImmNode* op) override;
+  Doc VisitExpr_(const StringImmNode* op) override;
+  Doc VisitExpr_(const CastNode* op) override;
+  Doc VisitExpr_(const VarNode* op) override;
+  Doc VisitExpr_(const AddNode* op) override;
+  Doc VisitExpr_(const SubNode* op) override;
+  Doc VisitExpr_(const MulNode* op) override;
+  Doc VisitExpr_(const DivNode* op) override;
+  Doc VisitExpr_(const ModNode* op) override;
+  Doc VisitExpr_(const FloorDivNode* op) override;
+  Doc VisitExpr_(const FloorModNode* op) override;
+  Doc VisitExpr_(const MinNode* op) override;
+  Doc VisitExpr_(const MaxNode* op) override;
+  Doc VisitExpr_(const EQNode* op) override;
+  Doc VisitExpr_(const NENode* op) override;
+  Doc VisitExpr_(const LTNode* op) override;
+  Doc VisitExpr_(const LENode* op) override;
+  Doc VisitExpr_(const GTNode* op) override;
+  Doc VisitExpr_(const GENode* op) override;
+  Doc VisitExpr_(const AndNode* op) override;
+  Doc VisitExpr_(const OrNode* op) override;
+  Doc VisitExpr_(const NotNode* op) override;
+  Doc VisitExpr_(const SelectNode* op) override;
+  Doc VisitExpr_(const BufferLoadNode* op) override;
+  Doc VisitExpr_(const LoadNode* op) override;
+  Doc VisitExpr_(const RampNode* op) override;
+  Doc VisitExpr_(const BroadcastNode* op) override;
+  Doc VisitExpr_(const LetNode* op) override;
+  Doc VisitExpr_(const CallNode* op) override;
+  Doc VisitExpr_(const ShuffleNode* op) override;
+  Doc VisitExpr_(const ReduceNode* op) override;
+  Doc VisitExprDefault_(const Object* op) override;
+
+  Doc VisitStmt_(const LetStmtNode* op) override;
+  Doc VisitStmt_(const AttrStmtNode* op) override;
+  Doc VisitStmt_(const AssertStmtNode* op) override;
+  Doc VisitStmt_(const StoreNode* op) override;
+  Doc VisitStmt_(const BufferStoreNode* op) override;
+  Doc VisitStmt_(const BufferRealizeNode* op) override;
+  Doc VisitStmt_(const AllocateNode* op) override;
+  Doc VisitStmt_(const FreeNode* op) override;
+  Doc VisitStmt_(const IfThenElseNode* op) override;
+  Doc VisitStmt_(const SeqStmtNode* op) override;
+  Doc VisitStmt_(const EvaluateNode* op) override;
+  Doc VisitStmt_(const ForNode* op) override;
+  Doc VisitStmt_(const PrefetchNode* op) override;
+  Doc VisitStmtDefault_(const Object* op) override;
+
+  Doc VisitType_(const PrimTypeNode* node) override;
+  Doc VisitType_(const PointerTypeNode* node) override;
+  Doc VisitType_(const TupleTypeNode* node) override;
+
+  Doc PrintIRModule(const IRModule& module);
+  Doc PrintPrimFunc(const PrimFunc& primFunc);
+  Doc PrintArray(const ArrayNode* op);
+  Doc PrintIterVar(const IterVarNode* op);
+  Doc PrintRange(const RangeNode* op);
+  Doc PrintBuffer(const BufferNode* op);
+  Doc PrintString(const StringObj* op) {
+    return Doc::StrLiteral(op->data);
+  }
+
+  /*!
+   * \brief special method to print out data type
+   * \param dtype The data type
+   */
+  static Doc PrintDType(DataType dtype) {
+    return Doc::Text(runtime::DLDataType2String(dtype));

Review comment:
       Shall we consider printing something non-empty when `dtype.bits() == 0`?
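
A minimal standalone sketch of the idea in this comment, assuming a zero-bit dtype would otherwise print as empty text. `SimpleDType`, `TypeCode`, and `DTypeToText` are hypothetical stand-ins, not TVM types, and the `"void"` spelling is an assumption that the PR does not fix.

```cpp
#include <cstdint>
#include <string>

// Hypothetical stand-in for a DLPack-style data type; not a TVM type.
enum class TypeCode : std::uint8_t { kInt, kUInt, kFloat, kHandle };

struct SimpleDType {
  TypeCode code;
  std::uint8_t bits;
  std::uint16_t lanes;
};

// Returns a textual form of the data type that is never empty.
// Assumption: a zero-bit type is spelled "void".
std::string DTypeToText(const SimpleDType& t) {
  if (t.bits == 0) return "void";
  std::string s;
  switch (t.code) {
    case TypeCode::kInt: s = "int"; break;
    case TypeCode::kUInt: s = "uint"; break;
    case TypeCode::kFloat: s = "float"; break;
    case TypeCode::kHandle: return "handle";
  }
  s += std::to_string(static_cast<int>(t.bits));
  // Vector types carry a lane suffix, e.g. float32x4.
  if (t.lanes > 1) s += "x" + std::to_string(static_cast<int>(t.lanes));
  return s;
}
```

With a non-empty spelling, a parser can tell a missing dtype apart from a zero-bit one.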







[GitHub] [incubator-tvm] tqchen commented on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-624696569


   Going to merge it in two days





[GitHub] [incubator-tvm] junrushao1994 commented on a change in pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r418367759



##########
File path: src/printer/tir_text_printer.cc
##########
@@ -0,0 +1,735 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file printer/tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ *  \brief Meta node collector
+ *  If we decide to put some node into meta, then all the sub-nodes inside
+ *  it need to be put in meta as well, since when parsing we need to know
+ *  whether two refs are the same
+ */
+class MetaCollector : public StmtExprVisitor {
+ public:
+  explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
+
+  void Collect(const ObjectRef& n) {
+    // these nodes can be print directly(StringLiteral or use identifier to identify)
+    if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>()
+        || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+      return;
+    }
+    if (n->IsInstance<StmtNode>()) {
+      VisitStmt(Downcast<Stmt>(n));
+    } else if (n->IsInstance<PrimExprNode>()) {
+      VisitExpr(Downcast<PrimExpr>(n));
+    }
+  }
+
+  void VisitStmt(const Stmt& n) override {
+    meta_->GetMetaNode(n);
+    StmtVisitor::VisitStmt(n);
+  }
+
+  void VisitExpr(const PrimExpr& n) override {
+    meta_->GetMetaNode(n);
+    ExprVisitor::VisitExpr(n);
+  }
+
+ private:
+  TextMetaDataContext* meta_;
+};
+
+class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
+                       public ExprFunctor<Doc(const PrimExpr&)>,
+                       public TypeFunctor<Doc(const Type&)> {
+ public:
+  explicit TIRTextPrinter(bool show_meta) : show_meta_(show_meta), meta_collector_(&meta_) {}
+
+  /*! \brief Print the node */
+  Doc Print(const ObjectRef& node);
+
+ private:
+  /*! \brief whether show meta data */
+  bool show_meta_;
+  /*! \brief meta data context */
+  TextMetaDataContext meta_;
+  /*! \brief meta collector */
+  MetaCollector meta_collector_;
+  /*! \brief Map from Var to Doc */
+  std::unordered_map<Var, Doc, ObjectHash, ObjectEqual> memo_var_;
+  /*! \brief Map from Buffer to Doc */
+  std::unordered_map<Buffer, Doc, ObjectHash, ObjectEqual> memo_buf_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+
+  Doc VisitExpr_(const IntImmNode* op) override;
+  Doc VisitExpr_(const FloatImmNode* op) override;
+  Doc VisitExpr_(const StringImmNode* op) override;
+  Doc VisitExpr_(const CastNode* op) override;
+  Doc VisitExpr_(const VarNode* op) override;
+  Doc VisitExpr_(const AddNode* op) override;
+  Doc VisitExpr_(const SubNode* op) override;
+  Doc VisitExpr_(const MulNode* op) override;
+  Doc VisitExpr_(const DivNode* op) override;
+  Doc VisitExpr_(const ModNode* op) override;
+  Doc VisitExpr_(const FloorDivNode* op) override;
+  Doc VisitExpr_(const FloorModNode* op) override;
+  Doc VisitExpr_(const MinNode* op) override;
+  Doc VisitExpr_(const MaxNode* op) override;
+  Doc VisitExpr_(const EQNode* op) override;
+  Doc VisitExpr_(const NENode* op) override;
+  Doc VisitExpr_(const LTNode* op) override;
+  Doc VisitExpr_(const LENode* op) override;
+  Doc VisitExpr_(const GTNode* op) override;
+  Doc VisitExpr_(const GENode* op) override;
+  Doc VisitExpr_(const AndNode* op) override;
+  Doc VisitExpr_(const OrNode* op) override;
+  Doc VisitExpr_(const NotNode* op) override;
+  Doc VisitExpr_(const SelectNode* op) override;
+  Doc VisitExpr_(const BufferLoadNode* op) override;
+  Doc VisitExpr_(const LoadNode* op) override;
+  Doc VisitExpr_(const RampNode* op) override;
+  Doc VisitExpr_(const BroadcastNode* op) override;
+  Doc VisitExpr_(const LetNode* op) override;
+  Doc VisitExpr_(const CallNode* op) override;
+  Doc VisitExpr_(const ShuffleNode* op) override;
+  Doc VisitExpr_(const ReduceNode* op) override;
+  Doc VisitExprDefault_(const Object* op) override;
+
+  Doc VisitStmt_(const LetStmtNode* op) override;
+  Doc VisitStmt_(const AttrStmtNode* op) override;
+  Doc VisitStmt_(const AssertStmtNode* op) override;
+  Doc VisitStmt_(const StoreNode* op) override;
+  Doc VisitStmt_(const BufferStoreNode* op) override;
+  Doc VisitStmt_(const BufferRealizeNode* op) override;
+  Doc VisitStmt_(const AllocateNode* op) override;
+  Doc VisitStmt_(const FreeNode* op) override;
+  Doc VisitStmt_(const IfThenElseNode* op) override;
+  Doc VisitStmt_(const SeqStmtNode* op) override;
+  Doc VisitStmt_(const EvaluateNode* op) override;
+  Doc VisitStmt_(const ForNode* op) override;
+  Doc VisitStmt_(const PrefetchNode* op) override;
+  Doc VisitStmtDefault_(const Object* op) override;
+
+  Doc VisitType_(const PrimTypeNode* node) override;
+  Doc VisitType_(const PointerTypeNode* node) override;
+  Doc VisitType_(const TupleTypeNode* node) override;
+
+  Doc PrintIRModule(const IRModule& module);
+  Doc PrintPrimFunc(const PrimFunc& primFunc);
+  Doc PrintArray(const ArrayNode* op);
+  Doc PrintIterVar(const IterVarNode* op);
+  Doc PrintRange(const RangeNode* op);
+  Doc PrintBuffer(const BufferNode* op);
+  Doc PrintString(const StringObj* op) {
+    return Doc::StrLiteral(op->data);
+  }
+
+  /*!
+   * \brief special method to print out data type
+   * \param dtype The data type
+   */
+  static Doc PrintDType(DataType dtype) {
+    return Doc::Text(runtime::DLDataType2String(dtype));
+  }
+
+  /*!
+   * \brief special method to print out const scalar
+   * \param dtype The data type
+   * \param data The pointer to hold the data.
+   */
+  template <typename T>
+  static Doc PrintConstScalar(DataType dtype, const T* data) {
+    Doc doc;
+    std::ostringstream os;
+    os << data[0];
+    if (dtype == DataType::Int(32)) {
+      doc << Doc::Text(os.str());
+    } else {
+      doc << PrintDType(dtype) << "(" << Doc::Text(os.str()) << ")";
+    }
+    return doc;
+  }
+
+  Doc GetUniqueName(std::string prefix) {
+    // std::replace(prefix.begin(), prefix.end(), '.', '_');
+    std::string unique_prefix = prefix;
+    auto it = name_alloc_map_.find(prefix);
+    if (it != name_alloc_map_.end()) {
+      while (true) {
+        std::ostringstream os;
+        os << prefix << "_" << (++it->second);
+        std::string name = os.str();
+        if (name_alloc_map_.count(name) == 0) {
+          unique_prefix = name;
+          break;
+        }
+      }

Review comment:
       `std::ostringstream` is not necessary in this case :-)
   
   ```suggestion
         while (name_alloc_map_.count(
           unique_prefix =
             prefix + "_" + std::to_string(++it->second)
         ) > 0);
   ```
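
For reference, the suggestion can be tried outside the printer with a standalone sketch like the one below. It is a hypothetical free-function version of `GetUniqueName`, with the map passed in explicitly in place of the `name_alloc_map_` member, and it returns a `std::string` rather than a `Doc`. The loop is written as a `do`/`while` instead of an empty-bodied `while`, which should behave the same way.

```cpp
#include <string>
#include <unordered_map>

// Hypothetical free-function variant; `names` plays the role of name_alloc_map_.
std::string GetUniqueName(std::unordered_map<std::string, int>& names,
                          const std::string& prefix) {
  std::string unique = prefix;
  auto it = names.find(prefix);
  if (it != names.end()) {
    // Bump the per-prefix counter until the candidate name is unused.
    do {
      unique = prefix + "_" + std::to_string(++it->second);
    } while (names.count(unique) > 0);
  }
  // Register the chosen name so later requests for it get their own suffix.
  names[unique] = 0;
  return unique;
}
```

Requesting `A` three times against an empty map should give `A`, `A_1`, and `A_2`; a later request for `A_1` gives `A_1_1`, because `A_1` is already registered.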







[GitHub] [incubator-tvm] Hzfengsy commented on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-621810131


   Some comments:
   1. There is more than one space before the left brace in the allocation line:
       ```
       allocate(B.local, float32, [64])  {
       ```
   
   2. Can we use the same rule for the allocation stmt as for attr? The allocation stmt currently introduces an extra level of indentation.
   
   3. ```
       attr [IterVar(blockIdx.z: int32, [(nullptr)], "ThreadIndex", "blockIdx.z")] "thread_extent" = 196;
       ```
       It is strange to print `nullptr` here, especially in square brackets. Perhaps we can use `IterVar(blockIdx.z: int32, (nullptr), "ThreadIndex", "blockIdx.z")` or `IterVar(blockIdx.z: int32, , "ThreadIndex", "blockIdx.z")`.
   
   4. Considering future parsing use, we must print the dtype for every constant. We could use shorthand for common dtypes, e.g. `2f` for float32, `2h` for float16 (half), and a bare `2` for int32 (most integers in a schedule are int32), while keeping the complete form legal for every type, e.g. `int8(2)`, `float64(2)` (or maybe `fp64(2)`), and also `float32(2)`.
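
The shorthand rule in point 4 could look roughly like the standalone sketch below. `ConstScalarToText` is a hypothetical helper, not the printer's `PrintConstScalar`: it takes the dtype as a plain string and returns a `std::string`, and the `f`/`h` suffixes are only the proposal from this comment.

```cpp
#include <sstream>
#include <string>

// Hypothetical helper: the dtype is given by its textual name, e.g. "int32".
template <typename T>
std::string ConstScalarToText(const std::string& dtype, T value) {
  std::ostringstream os;
  os << value;
  if (dtype == "int32") return os.str();          // bare form, e.g. 2
  if (dtype == "float32") return os.str() + "f";  // shorthand, e.g. 2f
  if (dtype == "float16") return os.str() + "h";  // shorthand, e.g. 2h
  return dtype + "(" + os.str() + ")";            // complete form, e.g. int8(2)
}
```

A parser for this format would still need to accept the complete form for the shorthand dtypes, so that `float32(2)` and `2f` read back as the same constant.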
   





[GitHub] [incubator-tvm] spectrometerHBH commented on a change in pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r418401170



##########
File path: src/printer/tir_text_printer.cc
##########
@@ -0,0 +1,735 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file printer/tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ *  \brief Meta node collector
+ *  If we decide to put some node into meta, then all the sub-nodes inside
+ *  it need to be put in meta as well, since when parsing we need to know
+ *  whether two refs are the same
+ */
+class MetaCollector : public StmtExprVisitor {
+ public:
+  explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
+
+  void Collect(const ObjectRef& n) {
+    // these nodes can be print directly(StringLiteral or use identifier to identify)
+    if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>()
+        || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+      return;
+    }
+    if (n->IsInstance<StmtNode>()) {
+      VisitStmt(Downcast<Stmt>(n));
+    } else if (n->IsInstance<PrimExprNode>()) {
+      VisitExpr(Downcast<PrimExpr>(n));
+    }
+  }
+
+  void VisitStmt(const Stmt& n) override {
+    meta_->GetMetaNode(n);
+    StmtVisitor::VisitStmt(n);
+  }
+
+  void VisitExpr(const PrimExpr& n) override {
+    meta_->GetMetaNode(n);
+    ExprVisitor::VisitExpr(n);
+  }
+
+ private:
+  TextMetaDataContext* meta_;
+};
+
+class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
+                       public ExprFunctor<Doc(const PrimExpr&)>,
+                       public TypeFunctor<Doc(const Type&)> {
+ public:
+  explicit TIRTextPrinter(bool show_meta) : show_meta_(show_meta), meta_collector_(&meta_) {}
+
+  /*! \brief Print the node */
+  Doc Print(const ObjectRef& node);
+
+ private:
+  /*! \brief whether show meta data */
+  bool show_meta_;
+  /*! \brief meta data context */
+  TextMetaDataContext meta_;
+  /*! \brief meta collector */
+  MetaCollector meta_collector_;
+  /*! \brief Map from Var to Doc */
+  std::unordered_map<Var, Doc, ObjectHash, ObjectEqual> memo_var_;
+  /*! \brief Map from Buffer to Doc */
+  std::unordered_map<Buffer, Doc, ObjectHash, ObjectEqual> memo_buf_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+
+  Doc VisitExpr_(const IntImmNode* op) override;
+  Doc VisitExpr_(const FloatImmNode* op) override;
+  Doc VisitExpr_(const StringImmNode* op) override;
+  Doc VisitExpr_(const CastNode* op) override;
+  Doc VisitExpr_(const VarNode* op) override;
+  Doc VisitExpr_(const AddNode* op) override;
+  Doc VisitExpr_(const SubNode* op) override;
+  Doc VisitExpr_(const MulNode* op) override;
+  Doc VisitExpr_(const DivNode* op) override;
+  Doc VisitExpr_(const ModNode* op) override;
+  Doc VisitExpr_(const FloorDivNode* op) override;
+  Doc VisitExpr_(const FloorModNode* op) override;
+  Doc VisitExpr_(const MinNode* op) override;
+  Doc VisitExpr_(const MaxNode* op) override;
+  Doc VisitExpr_(const EQNode* op) override;
+  Doc VisitExpr_(const NENode* op) override;
+  Doc VisitExpr_(const LTNode* op) override;
+  Doc VisitExpr_(const LENode* op) override;
+  Doc VisitExpr_(const GTNode* op) override;
+  Doc VisitExpr_(const GENode* op) override;
+  Doc VisitExpr_(const AndNode* op) override;
+  Doc VisitExpr_(const OrNode* op) override;
+  Doc VisitExpr_(const NotNode* op) override;
+  Doc VisitExpr_(const SelectNode* op) override;
+  Doc VisitExpr_(const BufferLoadNode* op) override;
+  Doc VisitExpr_(const LoadNode* op) override;
+  Doc VisitExpr_(const RampNode* op) override;
+  Doc VisitExpr_(const BroadcastNode* op) override;
+  Doc VisitExpr_(const LetNode* op) override;
+  Doc VisitExpr_(const CallNode* op) override;
+  Doc VisitExpr_(const ShuffleNode* op) override;
+  Doc VisitExpr_(const ReduceNode* op) override;
+  Doc VisitExprDefault_(const Object* op) override;
+
+  Doc VisitStmt_(const LetStmtNode* op) override;
+  Doc VisitStmt_(const AttrStmtNode* op) override;
+  Doc VisitStmt_(const AssertStmtNode* op) override;
+  Doc VisitStmt_(const StoreNode* op) override;
+  Doc VisitStmt_(const BufferStoreNode* op) override;
+  Doc VisitStmt_(const BufferRealizeNode* op) override;
+  Doc VisitStmt_(const AllocateNode* op) override;
+  Doc VisitStmt_(const FreeNode* op) override;
+  Doc VisitStmt_(const IfThenElseNode* op) override;
+  Doc VisitStmt_(const SeqStmtNode* op) override;
+  Doc VisitStmt_(const EvaluateNode* op) override;
+  Doc VisitStmt_(const ForNode* op) override;
+  Doc VisitStmt_(const PrefetchNode* op) override;
+  Doc VisitStmtDefault_(const Object* op) override;
+
+  Doc VisitType_(const PrimTypeNode* node) override;
+  Doc VisitType_(const PointerTypeNode* node) override;
+  Doc VisitType_(const TupleTypeNode* node) override;
+
+  Doc PrintIRModule(const IRModule& module);
+  Doc PrintPrimFunc(const PrimFunc& primFunc);
+  Doc PrintArray(const ArrayNode* op);
+  Doc PrintIterVar(const IterVarNode* op);
+  Doc PrintRange(const RangeNode* op);
+  Doc PrintBuffer(const BufferNode* op);
+  Doc PrintString(const StringObj* op) {
+    return Doc::StrLiteral(op->data);
+  }
+
+  /*!
+   * \brief special method to print out data type
+   * \param dtype The data type
+   */
+  static Doc PrintDType(DataType dtype) {
+    return Doc::Text(runtime::DLDataType2String(dtype));
+  }
+
+  /*!
+   * \brief special method to print out const scalar
+   * \param dtype The data type
+   * \param data The pointer to hold the data.
+   */
+  template <typename T>
+  static Doc PrintConstScalar(DataType dtype, const T* data) {
+    Doc doc;
+    std::ostringstream os;
+    os << data[0];
+    if (dtype == DataType::Int(32)) {
+      doc << Doc::Text(os.str());
+    } else {
+      doc << PrintDType(dtype) << "(" << Doc::Text(os.str()) << ")";
+    }
+    return doc;
+  }
+
+  Doc GetUniqueName(std::string prefix) {
+    // std::replace(prefix.begin(), prefix.end(), '.', '_');
+    std::string unique_prefix = prefix;
+    auto it = name_alloc_map_.find(prefix);
+    if (it != name_alloc_map_.end()) {
+      while (true) {
+        std::ostringstream os;
+        os << prefix << "_" << (++it->second);
+        std::string name = os.str();
+        if (name_alloc_map_.count(name) == 0) {
+          unique_prefix = name;
+          break;
+        }
+      }
+    }
+    name_alloc_map_[unique_prefix] = 0;
+    return Doc::Text(unique_prefix);
+  }
+
+  Doc AllocVar(const Var& var) {
+    const auto& it = memo_var_.find(var);
+    if (it != memo_var_.end()) {
+      return it->second;
+    }
+    std::string name = var->name_hint;
+    if (name.length() == 0 || !std::isalpha(name[0])) {
+      name = "v" + name;
+    }
+    Doc val = GetUniqueName(name);
+    memo_var_[var] = val;
+    return val << ": " << Print(GetType(var));
+  }
+
+  Doc AllocBuf(const Buffer& buffer) {
+    const auto& it = memo_buf_.find(buffer);
+    if (it != memo_buf_.end()) {
+      return it->second;
+    }
+    std::string name = buffer->name;
+    if (name.length() == 0 || !std::isalpha(name[0])) {
+      name = "buf_" + name;
+    }
+    Doc val = GetUniqueName(name);
+    memo_buf_[buffer] = val;
+    return val;
+  }
+
+  /*!
+   * \brief special method to render vectors of docs with a separator
+   * \param vec vector of docs
+   * \param sep separator
+   */
+  static Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
+    Doc seq;
+    if (vec.size() != 0) {
+      seq = vec[0];
+      for (size_t i = 1; i < vec.size(); i++) {
+        seq << sep << vec[i];
+      }
+    }
+    return seq;
+  }
+
+  /*!
+   * \brief dump meta info
+   * \return Doc with meta info
+   */
+  Doc DumpMeta() {
+    if (show_meta_) {
+      return Doc::Text("__tvm_meta__ = ")
+          << (meta_.empty() ? Doc::Text("None") : meta_.GetMetaSection());
+    } else {
+      return Doc::Text("");
+    }
+  }
+
+  Doc PrintBody(const Stmt& body, bool indent = true) {
+    Doc doc;
+    if (body->IsInstance<SeqStmtNode>()) return Print(body);
+    doc << " {" << Doc::Indent(2, Doc::NewLine() << Print(body)) << Doc::NewLine() << "}";
+    return doc;
+  }
+};
+
+Doc TIRTextPrinter::Print(const ObjectRef& node) {
+  if (!node.defined()) return Doc::Text("(nullptr)");
+  if (node.as<StmtNode>()) {

Review comment:
       fixed

##########
File path: src/printer/tir_text_printer.cc
##########
@@ -0,0 +1,735 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file printer/tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ *  \brief Meta node collector
+ *  If we decide to put some node into meta, then all the sub-nodes inside
+ *  it need to be put in meta as well, since when parsing we need to know
+ *  whether two refs are the same
+ */
+class MetaCollector : public StmtExprVisitor {
+ public:
+  explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
+
+  void Collect(const ObjectRef& n) {
+    // these nodes can be print directly(StringLiteral or use identifier to identify)
+    if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>()
+        || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+      return;
+    }
+    if (n->IsInstance<StmtNode>()) {
+      VisitStmt(Downcast<Stmt>(n));
+    } else if (n->IsInstance<PrimExprNode>()) {
+      VisitExpr(Downcast<PrimExpr>(n));
+    }
+  }
+
+  void VisitStmt(const Stmt& n) override {
+    meta_->GetMetaNode(n);
+    StmtVisitor::VisitStmt(n);
+  }
+
+  void VisitExpr(const PrimExpr& n) override {
+    meta_->GetMetaNode(n);
+    ExprVisitor::VisitExpr(n);
+  }
+
+ private:
+  TextMetaDataContext* meta_;
+};
+
+class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
+                       public ExprFunctor<Doc(const PrimExpr&)>,
+                       public TypeFunctor<Doc(const Type&)> {
+ public:
+  explicit TIRTextPrinter(bool show_meta) : show_meta_(show_meta), meta_collector_(&meta_) {}
+
+  /*! \brief Print the node */
+  Doc Print(const ObjectRef& node);
+
+ private:
+  /*! \brief whether show meta data */
+  bool show_meta_;
+  /*! \brief meta data context */
+  TextMetaDataContext meta_;
+  /*! \brief meta collector */
+  MetaCollector meta_collector_;
+  /*! \brief Map from Var to Doc */
+  std::unordered_map<Var, Doc, ObjectHash, ObjectEqual> memo_var_;
+  /*! \brief Map from Buffer to Doc */
+  std::unordered_map<Buffer, Doc, ObjectHash, ObjectEqual> memo_buf_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+
+  Doc VisitExpr_(const IntImmNode* op) override;
+  Doc VisitExpr_(const FloatImmNode* op) override;
+  Doc VisitExpr_(const StringImmNode* op) override;
+  Doc VisitExpr_(const CastNode* op) override;
+  Doc VisitExpr_(const VarNode* op) override;
+  Doc VisitExpr_(const AddNode* op) override;
+  Doc VisitExpr_(const SubNode* op) override;
+  Doc VisitExpr_(const MulNode* op) override;
+  Doc VisitExpr_(const DivNode* op) override;
+  Doc VisitExpr_(const ModNode* op) override;
+  Doc VisitExpr_(const FloorDivNode* op) override;
+  Doc VisitExpr_(const FloorModNode* op) override;
+  Doc VisitExpr_(const MinNode* op) override;
+  Doc VisitExpr_(const MaxNode* op) override;
+  Doc VisitExpr_(const EQNode* op) override;
+  Doc VisitExpr_(const NENode* op) override;
+  Doc VisitExpr_(const LTNode* op) override;
+  Doc VisitExpr_(const LENode* op) override;
+  Doc VisitExpr_(const GTNode* op) override;
+  Doc VisitExpr_(const GENode* op) override;
+  Doc VisitExpr_(const AndNode* op) override;
+  Doc VisitExpr_(const OrNode* op) override;
+  Doc VisitExpr_(const NotNode* op) override;
+  Doc VisitExpr_(const SelectNode* op) override;
+  Doc VisitExpr_(const BufferLoadNode* op) override;
+  Doc VisitExpr_(const LoadNode* op) override;
+  Doc VisitExpr_(const RampNode* op) override;
+  Doc VisitExpr_(const BroadcastNode* op) override;
+  Doc VisitExpr_(const LetNode* op) override;
+  Doc VisitExpr_(const CallNode* op) override;
+  Doc VisitExpr_(const ShuffleNode* op) override;
+  Doc VisitExpr_(const ReduceNode* op) override;
+  Doc VisitExprDefault_(const Object* op) override;
+
+  Doc VisitStmt_(const LetStmtNode* op) override;
+  Doc VisitStmt_(const AttrStmtNode* op) override;
+  Doc VisitStmt_(const AssertStmtNode* op) override;
+  Doc VisitStmt_(const StoreNode* op) override;
+  Doc VisitStmt_(const BufferStoreNode* op) override;
+  Doc VisitStmt_(const BufferRealizeNode* op) override;
+  Doc VisitStmt_(const AllocateNode* op) override;
+  Doc VisitStmt_(const FreeNode* op) override;
+  Doc VisitStmt_(const IfThenElseNode* op) override;
+  Doc VisitStmt_(const SeqStmtNode* op) override;
+  Doc VisitStmt_(const EvaluateNode* op) override;
+  Doc VisitStmt_(const ForNode* op) override;
+  Doc VisitStmt_(const PrefetchNode* op) override;
+  Doc VisitStmtDefault_(const Object* op) override;
+
+  Doc VisitType_(const PrimTypeNode* node) override;
+  Doc VisitType_(const PointerTypeNode* node) override;
+  Doc VisitType_(const TupleTypeNode* node) override;
+
+  Doc PrintIRModule(const IRModule& module);
+  Doc PrintPrimFunc(const PrimFunc& primFunc);
+  Doc PrintArray(const ArrayNode* op);
+  Doc PrintIterVar(const IterVarNode* op);
+  Doc PrintRange(const RangeNode* op);
+  Doc PrintBuffer(const BufferNode* op);
+  Doc PrintString(const StringObj* op) {
+    return Doc::StrLiteral(op->data);
+  }
+
+  /*!
+   * \brief special method to print out data type
+   * \param dtype The data type
+   */
+  static Doc PrintDType(DataType dtype) {
+    return Doc::Text(runtime::DLDataType2String(dtype));
+  }
+
+  /*!
+   * \brief special method to print out const scalar
+   * \param dtype The data type
+   * \param data The pointer to hold the data.
+   */
+  template <typename T>
+  static Doc PrintConstScalar(DataType dtype, const T* data) {
+    Doc doc;
+    std::ostringstream os;
+    os << data[0];
+    if (dtype == DataType::Int(32)) {
+      doc << Doc::Text(os.str());
+    } else {
+      doc << PrintDType(dtype) << "(" << Doc::Text(os.str()) << ")";
+    }
+    return doc;
+  }
+
+  Doc GetUniqueName(std::string prefix) {
+    // std::replace(prefix.begin(), prefix.end(), '.', '_');
+    std::string unique_prefix = prefix;
+    auto it = name_alloc_map_.find(prefix);
+    if (it != name_alloc_map_.end()) {
+      while (true) {
+        std::ostringstream os;
+        os << prefix << "_" << (++it->second);
+        std::string name = os.str();
+        if (name_alloc_map_.count(name) == 0) {
+          unique_prefix = name;
+          break;
+        }
+      }

Review comment:
       fixed

##########
File path: src/printer/tir_text_printer.cc
##########
@@ -0,0 +1,735 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file printer/tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ *  \brief Meta node collector
+ *  If we decide to put some node into meta, then all the sub-nodes inside
+ *  it need to be put in meta as well, since when parsing we need to know
+ *  whether two refs are the same
+ */
+class MetaCollector : public StmtExprVisitor {
+ public:
+  explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
+
+  void Collect(const ObjectRef& n) {
+    // these nodes can be print directly(StringLiteral or use identifier to identify)
+    if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>()
+        || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+      return;
+    }
+    if (n->IsInstance<StmtNode>()) {
+      VisitStmt(Downcast<Stmt>(n));
+    } else if (n->IsInstance<PrimExprNode>()) {
+      VisitExpr(Downcast<PrimExpr>(n));
+    }
+  }
+
+  void VisitStmt(const Stmt& n) override {
+    meta_->GetMetaNode(n);
+    StmtVisitor::VisitStmt(n);
+  }
+
+  void VisitExpr(const PrimExpr& n) override {
+    meta_->GetMetaNode(n);
+    ExprVisitor::VisitExpr(n);
+  }
+
+ private:
+  TextMetaDataContext* meta_;
+};
+
+class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
+                       public ExprFunctor<Doc(const PrimExpr&)>,
+                       public TypeFunctor<Doc(const Type&)> {
+ public:
+  explicit TIRTextPrinter(bool show_meta) : show_meta_(show_meta), meta_collector_(&meta_) {}
+
+  /*! \brief Print the node */
+  Doc Print(const ObjectRef& node);
+
+ private:
+  /*! \brief whether show meta data */
+  bool show_meta_;
+  /*! \brief meta data context */
+  TextMetaDataContext meta_;
+  /*! \brief meta collector */
+  MetaCollector meta_collector_;
+  /*! \brief Map from Var to Doc */
+  std::unordered_map<Var, Doc, ObjectHash, ObjectEqual> memo_var_;
+  /*! \brief Map from Buffer to Doc */
+  std::unordered_map<Buffer, Doc, ObjectHash, ObjectEqual> memo_buf_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+
+  Doc VisitExpr_(const IntImmNode* op) override;
+  Doc VisitExpr_(const FloatImmNode* op) override;
+  Doc VisitExpr_(const StringImmNode* op) override;
+  Doc VisitExpr_(const CastNode* op) override;
+  Doc VisitExpr_(const VarNode* op) override;
+  Doc VisitExpr_(const AddNode* op) override;
+  Doc VisitExpr_(const SubNode* op) override;
+  Doc VisitExpr_(const MulNode* op) override;
+  Doc VisitExpr_(const DivNode* op) override;
+  Doc VisitExpr_(const ModNode* op) override;
+  Doc VisitExpr_(const FloorDivNode* op) override;
+  Doc VisitExpr_(const FloorModNode* op) override;
+  Doc VisitExpr_(const MinNode* op) override;
+  Doc VisitExpr_(const MaxNode* op) override;
+  Doc VisitExpr_(const EQNode* op) override;
+  Doc VisitExpr_(const NENode* op) override;
+  Doc VisitExpr_(const LTNode* op) override;
+  Doc VisitExpr_(const LENode* op) override;
+  Doc VisitExpr_(const GTNode* op) override;
+  Doc VisitExpr_(const GENode* op) override;
+  Doc VisitExpr_(const AndNode* op) override;
+  Doc VisitExpr_(const OrNode* op) override;
+  Doc VisitExpr_(const NotNode* op) override;
+  Doc VisitExpr_(const SelectNode* op) override;
+  Doc VisitExpr_(const BufferLoadNode* op) override;
+  Doc VisitExpr_(const LoadNode* op) override;
+  Doc VisitExpr_(const RampNode* op) override;
+  Doc VisitExpr_(const BroadcastNode* op) override;
+  Doc VisitExpr_(const LetNode* op) override;
+  Doc VisitExpr_(const CallNode* op) override;
+  Doc VisitExpr_(const ShuffleNode* op) override;
+  Doc VisitExpr_(const ReduceNode* op) override;
+  Doc VisitExprDefault_(const Object* op) override;
+
+  Doc VisitStmt_(const LetStmtNode* op) override;
+  Doc VisitStmt_(const AttrStmtNode* op) override;
+  Doc VisitStmt_(const AssertStmtNode* op) override;
+  Doc VisitStmt_(const StoreNode* op) override;
+  Doc VisitStmt_(const BufferStoreNode* op) override;
+  Doc VisitStmt_(const BufferRealizeNode* op) override;
+  Doc VisitStmt_(const AllocateNode* op) override;
+  Doc VisitStmt_(const FreeNode* op) override;
+  Doc VisitStmt_(const IfThenElseNode* op) override;
+  Doc VisitStmt_(const SeqStmtNode* op) override;
+  Doc VisitStmt_(const EvaluateNode* op) override;
+  Doc VisitStmt_(const ForNode* op) override;
+  Doc VisitStmt_(const PrefetchNode* op) override;
+  Doc VisitStmtDefault_(const Object* op) override;
+
+  Doc VisitType_(const PrimTypeNode* node) override;
+  Doc VisitType_(const PointerTypeNode* node) override;
+  Doc VisitType_(const TupleTypeNode* node) override;
+
+  Doc PrintIRModule(const IRModule& module);
+  Doc PrintPrimFunc(const PrimFunc& primFunc);
+  Doc PrintArray(const ArrayNode* op);
+  Doc PrintIterVar(const IterVarNode* op);
+  Doc PrintRange(const RangeNode* op);
+  Doc PrintBuffer(const BufferNode* op);
+  Doc PrintString(const StringObj* op) {
+    return Doc::StrLiteral(op->data);
+  }
+
+  /*!
+   * \brief special method to print out data type
+   * \param dtype The data type
+   */
+  static Doc PrintDType(DataType dtype) {
+    return Doc::Text(runtime::DLDataType2String(dtype));
+  }
+
+  /*!
+   * \brief special method to print out const scalar
+   * \param dtype The data type
+   * \param data The pointer to hold the data.
+   */
+  template <typename T>
+  static Doc PrintConstScalar(DataType dtype, const T* data) {

Review comment:
       fixed




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tqchen edited a comment on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
tqchen edited a comment on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-622071330


   - In allocate, `float32x32` is not the actual data type. Perhaps allocate can just show the flattened size, since that is the semantics.
   - `eval` is not necessary, since it can be implied by a call.
   - We might need to update call later to something like `@intrin.func_name(args)`.
   - Perhaps we do not need to add a new nested block for allocate.





[GitHub] [incubator-tvm] Hzfengsy commented on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-625818864


   > Binds in some pass function is not clean for round trip dump ir, how do we deal with it?
   
   Do you mean that `buffer_bind` does not need to be printed after the `storage_flatten` pass?
   Yes, the buffer no longer exists after the flatten pass. All we have is the Var (`buffer->data`) in Load and Store. But we still need a place to define those vars, and that's why we still have `buffer_bind`, where we define the buffer as well as `buffer->data`, in low-level functions.
   
   Furthermore, maybe one day we can use BufferLoad/BufferStore after the flatten pass and use buffers from beginning to end.





[GitHub] [incubator-tvm] junrushao1994 commented on a change in pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r418365372



##########
File path: src/printer/tir_text_printer.cc
##########
@@ -0,0 +1,735 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file printer/tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ *  \brief Meta node collector
+ *  If we decide to put some node into meta, then all the sub-nodes inside
+ *  it need to be put in meta as well, since when parsing we need to know
+ *  whether two refs are the same
+ */
+class MetaCollector : public StmtExprVisitor {
+ public:
+  explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
+
+  void Collect(const ObjectRef& n) {
+    // these nodes can be print directly(StringLiteral or use identifier to identify)
+    if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>()
+        || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+      return;
+    }
+    if (n->IsInstance<StmtNode>()) {
+      VisitStmt(Downcast<Stmt>(n));
+    } else if (n->IsInstance<PrimExprNode>()) {
+      VisitExpr(Downcast<PrimExpr>(n));
+    }
+  }
+
+  void VisitStmt(const Stmt& n) override {
+    meta_->GetMetaNode(n);
+    StmtVisitor::VisitStmt(n);
+  }
+
+  void VisitExpr(const PrimExpr& n) override {
+    meta_->GetMetaNode(n);
+    ExprVisitor::VisitExpr(n);
+  }
+
+ private:
+  TextMetaDataContext* meta_;
+};
+
+class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
+                       public ExprFunctor<Doc(const PrimExpr&)>,
+                       public TypeFunctor<Doc(const Type&)> {
+ public:
+  explicit TIRTextPrinter(bool show_meta) : show_meta_(show_meta), meta_collector_(&meta_) {}
+
+  /*! \brief Print the node */
+  Doc Print(const ObjectRef& node);
+
+ private:
+  /*! \brief whether show meta data */
+  bool show_meta_;
+  /*! \brief meta data context */
+  TextMetaDataContext meta_;
+  /*! \brief meta collector */
+  MetaCollector meta_collector_;
+  /*! \brief Map from Var to Doc */
+  std::unordered_map<Var, Doc, ObjectHash, ObjectEqual> memo_var_;
+  /*! \brief Map from Buffer to Doc */
+  std::unordered_map<Buffer, Doc, ObjectHash, ObjectEqual> memo_buf_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+
+  Doc VisitExpr_(const IntImmNode* op) override;
+  Doc VisitExpr_(const FloatImmNode* op) override;
+  Doc VisitExpr_(const StringImmNode* op) override;
+  Doc VisitExpr_(const CastNode* op) override;
+  Doc VisitExpr_(const VarNode* op) override;
+  Doc VisitExpr_(const AddNode* op) override;
+  Doc VisitExpr_(const SubNode* op) override;
+  Doc VisitExpr_(const MulNode* op) override;
+  Doc VisitExpr_(const DivNode* op) override;
+  Doc VisitExpr_(const ModNode* op) override;
+  Doc VisitExpr_(const FloorDivNode* op) override;
+  Doc VisitExpr_(const FloorModNode* op) override;
+  Doc VisitExpr_(const MinNode* op) override;
+  Doc VisitExpr_(const MaxNode* op) override;
+  Doc VisitExpr_(const EQNode* op) override;
+  Doc VisitExpr_(const NENode* op) override;
+  Doc VisitExpr_(const LTNode* op) override;
+  Doc VisitExpr_(const LENode* op) override;
+  Doc VisitExpr_(const GTNode* op) override;
+  Doc VisitExpr_(const GENode* op) override;
+  Doc VisitExpr_(const AndNode* op) override;
+  Doc VisitExpr_(const OrNode* op) override;
+  Doc VisitExpr_(const NotNode* op) override;
+  Doc VisitExpr_(const SelectNode* op) override;
+  Doc VisitExpr_(const BufferLoadNode* op) override;
+  Doc VisitExpr_(const LoadNode* op) override;
+  Doc VisitExpr_(const RampNode* op) override;
+  Doc VisitExpr_(const BroadcastNode* op) override;
+  Doc VisitExpr_(const LetNode* op) override;
+  Doc VisitExpr_(const CallNode* op) override;
+  Doc VisitExpr_(const ShuffleNode* op) override;
+  Doc VisitExpr_(const ReduceNode* op) override;
+  Doc VisitExprDefault_(const Object* op) override;
+
+  Doc VisitStmt_(const LetStmtNode* op) override;
+  Doc VisitStmt_(const AttrStmtNode* op) override;
+  Doc VisitStmt_(const AssertStmtNode* op) override;
+  Doc VisitStmt_(const StoreNode* op) override;
+  Doc VisitStmt_(const BufferStoreNode* op) override;
+  Doc VisitStmt_(const BufferRealizeNode* op) override;
+  Doc VisitStmt_(const AllocateNode* op) override;
+  Doc VisitStmt_(const FreeNode* op) override;
+  Doc VisitStmt_(const IfThenElseNode* op) override;
+  Doc VisitStmt_(const SeqStmtNode* op) override;
+  Doc VisitStmt_(const EvaluateNode* op) override;
+  Doc VisitStmt_(const ForNode* op) override;
+  Doc VisitStmt_(const PrefetchNode* op) override;
+  Doc VisitStmtDefault_(const Object* op) override;
+
+  Doc VisitType_(const PrimTypeNode* node) override;
+  Doc VisitType_(const PointerTypeNode* node) override;
+  Doc VisitType_(const TupleTypeNode* node) override;
+
+  Doc PrintIRModule(const IRModule& module);
+  Doc PrintPrimFunc(const PrimFunc& primFunc);
+  Doc PrintArray(const ArrayNode* op);
+  Doc PrintIterVar(const IterVarNode* op);
+  Doc PrintRange(const RangeNode* op);
+  Doc PrintBuffer(const BufferNode* op);
+  Doc PrintString(const StringObj* op) {
+    return Doc::StrLiteral(op->data);
+  }
+
+  /*!
+   * \brief special method to print out data type
+   * \param dtype The data type
+   */
+  static Doc PrintDType(DataType dtype) {
+    return Doc::Text(runtime::DLDataType2String(dtype));
+  }
+
+  /*!
+   * \brief special method to print out const scalar
+   * \param dtype The data type
+   * \param data The pointer to hold the data.
+   */
+  template <typename T>
+  static Doc PrintConstScalar(DataType dtype, const T* data) {

Review comment:
       Why not just do it like this?
   
   ```suggestion
     static Doc PrintConstScalar(DataType dtype, const T& data) {
   ```
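
A minimal sketch of what the suggested signature could look like. To keep it compilable on its own, `Doc` is replaced by `std::string` and the dtype by its string name; both are stand-ins assumed for this sketch, not the TVM types. Taking `const T&` lets the caller pass the value directly and removes the `data[0]` dereference.

```cpp
#include <sstream>
#include <string>

// Stand-in for the printer's Doc type (an assumption for this sketch).
using Doc = std::string;

// Same logic as the quoted PrintConstScalar, but taking the scalar by
// const reference instead of by pointer. The dtype is passed as its
// string name here, which is also an assumption.
template <typename T>
Doc PrintConstScalar(const std::string& dtype, const T& data) {
  std::ostringstream os;
  os << data;
  // int32 constants are printed bare; every other dtype wraps the value.
  if (dtype == "int32") {
    return os.str();
  }
  return dtype + "(" + os.str() + ")";
}
```

With this shape a call site would pass the value itself rather than its address. The callers are not shown in this hunk, so how they would change is an assumption.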







[GitHub] [incubator-tvm] spectrometerHBH commented on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-622238350


   To avoid making this page too long, I will edit and update the reference examples at the top whenever I change the format.





[GitHub] [incubator-tvm] spectrometerHBH edited a comment on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH edited a comment on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-625889950


   I cannot reproduce the CI errors at the moment, and they look strange. I can't see why `str(tvm.tir.any(x < y, x > z))` would give `((x < y: int32) || (x: int32 > z: int32))`. I will try to fix it in Docker.
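
One possible explanation, which this thread does not confirm, is operand evaluation order. The printer emits a var's type only on first encounter (point 4 of the PR description), so the text depends on which operand gets printed first. The sketch below is a self-contained illustration with hypothetical names (`VarMemo`, `Cmp`, `RenderAny`), not the TVM printer. It forces each order explicitly and shows that printing operands right to left yields the string seen in CI.

```cpp
#include <set>
#include <string>

// Hypothetical stand-in for the printer's var memo: a var is annotated
// with its type only the first time it is printed.
struct VarMemo {
  std::set<std::string> seen;
  std::string Print(const std::string& name) {
    bool first = seen.insert(name).second;
    return first ? name + ": int32" : name;
  }
};

// Renders "(a op b)", printing the two operands in the requested order.
std::string Cmp(VarMemo& memo, const std::string& a, const std::string& op,
                const std::string& b, bool left_first) {
  std::string lhs, rhs;
  if (left_first) {
    lhs = memo.Print(a);
    rhs = memo.Print(b);
  } else {
    rhs = memo.Print(b);
    lhs = memo.Print(a);
  }
  return "(" + lhs + " " + op + " " + rhs + ")";
}

// Renders any(x < y, x > z) with a fresh memo, visiting sub-expressions
// either left to right or right to left.
std::string RenderAny(bool left_first) {
  VarMemo memo;
  std::string a, b;
  if (left_first) {
    a = Cmp(memo, "x", "<", "y", true);
    b = Cmp(memo, "x", ">", "z", true);
  } else {
    b = Cmp(memo, "x", ">", "z", false);
    a = Cmp(memo, "x", "<", "y", false);
  }
  return "(" + a + " || " + b + ")";
}
```

Here `RenderAny(true)` gives `((x: int32 < y: int32) || (x > z: int32))`, while `RenderAny(false)` gives `((x < y: int32) || (x: int32 > z: int32))`. Before C++17, the operands of a chained overloaded `operator<<` are not guaranteed to be evaluated left to right, so different compilers could produce either result. Storing each operand's printed form in a named local before concatenating would fix the order.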





[GitHub] [incubator-tvm] xqdan commented on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
xqdan commented on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-626509149


   > > Binds in some pass function is not clean for round trip dump ir, how do we deal with it?
   > 
   > Do you mean that `buffer_bind` does not need to be printed after the `storage_flatten` pass?
   > Yes, the buffer no longer exists after the flatten pass. All we have is the Var (`buffer->data`) in Load and Store. But we still need a place to define those vars, and that's why we still have `buffer_bind`, where we define the buffer as well as `buffer->data`, in low-level functions.
   > 
   > Furthermore, maybe one day we can use BufferLoad/BufferStore after the flatten pass and use buffers from beginning to end.
   
   https://github.com/apache/incubator-tvm/blob/master/src/driver/driver_api.cc#L139
   
   That's what I was asking about. It seems the binds are embedded into the IR at the beginning, so this looks good.





[GitHub] [incubator-tvm] junrushao1994 commented on a change in pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r418401901



##########
File path: src/printer/tir_text_printer.cc
##########
@@ -0,0 +1,735 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file printer/tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ *  \brief Meta node collector
+ *  If we decide to put some node into meta, then all the sub-nodes inside
+ *  it need to be put in meta as well, since when parsing we need to know
+ *  whether two refs are the same
+ */
+class MetaCollector : public StmtExprVisitor {
+ public:
+  explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
+
+  void Collect(const ObjectRef& n) {
+    // these nodes can be printed directly (as a string literal or by their identifier)
+    if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>()
+        || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+      return;
+    }
+    if (n->IsInstance<StmtNode>()) {
+      VisitStmt(Downcast<Stmt>(n));
+    } else if (n->IsInstance<PrimExprNode>()) {
+      VisitExpr(Downcast<PrimExpr>(n));
+    }
+  }
+
+  void VisitStmt(const Stmt& n) override {
+    meta_->GetMetaNode(n);
+    StmtVisitor::VisitStmt(n);
+  }
+
+  void VisitExpr(const PrimExpr& n) override {
+    meta_->GetMetaNode(n);
+    ExprVisitor::VisitExpr(n);
+  }
+
+ private:
+  TextMetaDataContext* meta_;
+};
+
+class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
+                       public ExprFunctor<Doc(const PrimExpr&)>,
+                       public TypeFunctor<Doc(const Type&)> {
+ public:
+  explicit TIRTextPrinter(bool show_meta) : show_meta_(show_meta), meta_collector_(&meta_) {}
+
+  /*! \brief Print the node */
+  Doc Print(const ObjectRef& node);
+
+ private:
+  /*! \brief whether show meta data */
+  bool show_meta_;
+  /*! \brief meta data context */
+  TextMetaDataContext meta_;
+  /*! \brief meta collector */
+  MetaCollector meta_collector_;
+  /*! \brief Map from Var to Doc */
+  std::unordered_map<Var, Doc, ObjectHash, ObjectEqual> memo_var_;
+  /*! \brief Map from Buffer to Doc */
+  std::unordered_map<Buffer, Doc, ObjectHash, ObjectEqual> memo_buf_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+
+  Doc VisitExpr_(const IntImmNode* op) override;
+  Doc VisitExpr_(const FloatImmNode* op) override;
+  Doc VisitExpr_(const StringImmNode* op) override;
+  Doc VisitExpr_(const CastNode* op) override;
+  Doc VisitExpr_(const VarNode* op) override;
+  Doc VisitExpr_(const AddNode* op) override;
+  Doc VisitExpr_(const SubNode* op) override;
+  Doc VisitExpr_(const MulNode* op) override;
+  Doc VisitExpr_(const DivNode* op) override;
+  Doc VisitExpr_(const ModNode* op) override;
+  Doc VisitExpr_(const FloorDivNode* op) override;
+  Doc VisitExpr_(const FloorModNode* op) override;
+  Doc VisitExpr_(const MinNode* op) override;
+  Doc VisitExpr_(const MaxNode* op) override;
+  Doc VisitExpr_(const EQNode* op) override;
+  Doc VisitExpr_(const NENode* op) override;
+  Doc VisitExpr_(const LTNode* op) override;
+  Doc VisitExpr_(const LENode* op) override;
+  Doc VisitExpr_(const GTNode* op) override;
+  Doc VisitExpr_(const GENode* op) override;
+  Doc VisitExpr_(const AndNode* op) override;
+  Doc VisitExpr_(const OrNode* op) override;
+  Doc VisitExpr_(const NotNode* op) override;
+  Doc VisitExpr_(const SelectNode* op) override;
+  Doc VisitExpr_(const BufferLoadNode* op) override;
+  Doc VisitExpr_(const LoadNode* op) override;
+  Doc VisitExpr_(const RampNode* op) override;
+  Doc VisitExpr_(const BroadcastNode* op) override;
+  Doc VisitExpr_(const LetNode* op) override;
+  Doc VisitExpr_(const CallNode* op) override;
+  Doc VisitExpr_(const ShuffleNode* op) override;
+  Doc VisitExpr_(const ReduceNode* op) override;
+  Doc VisitExprDefault_(const Object* op) override;
+
+  Doc VisitStmt_(const LetStmtNode* op) override;
+  Doc VisitStmt_(const AttrStmtNode* op) override;
+  Doc VisitStmt_(const AssertStmtNode* op) override;
+  Doc VisitStmt_(const StoreNode* op) override;
+  Doc VisitStmt_(const BufferStoreNode* op) override;
+  Doc VisitStmt_(const BufferRealizeNode* op) override;
+  Doc VisitStmt_(const AllocateNode* op) override;
+  Doc VisitStmt_(const FreeNode* op) override;
+  Doc VisitStmt_(const IfThenElseNode* op) override;
+  Doc VisitStmt_(const SeqStmtNode* op) override;
+  Doc VisitStmt_(const EvaluateNode* op) override;
+  Doc VisitStmt_(const ForNode* op) override;
+  Doc VisitStmt_(const PrefetchNode* op) override;
+  Doc VisitStmtDefault_(const Object* op) override;
+
+  Doc VisitType_(const PrimTypeNode* node) override;
+  Doc VisitType_(const PointerTypeNode* node) override;
+  Doc VisitType_(const TupleTypeNode* node) override;
+
+  Doc PrintIRModule(const IRModule& module);
+  Doc PrintPrimFunc(const PrimFunc& primFunc);
+  Doc PrintArray(const ArrayNode* op);
+  Doc PrintIterVar(const IterVarNode* op);
+  Doc PrintRange(const RangeNode* op);
+  Doc PrintBuffer(const BufferNode* op);
+  Doc PrintString(const StringObj* op) {
+    return Doc::StrLiteral(op->data);
+  }
+
+  /*!
+   * \brief special method to print out data type
+   * \param dtype The data type
+   */
+  static Doc PrintDType(DataType dtype) {
+    return Doc::Text(runtime::DLDataType2String(dtype));
+  }
+
+  /*!
+   * \brief special method to print out const scalar
+   * \param dtype The data type
+   * \param data The pointer to hold the data.
+   */
+  template <typename T>
+  static Doc PrintConstScalar(DataType dtype, const T* data) {
+    Doc doc;
+    std::ostringstream os;
+    os << data[0];
+    if (dtype == DataType::Int(32)) {
+      doc << Doc::Text(os.str());
+    } else {
+      doc << PrintDType(dtype) << "(" << Doc::Text(os.str()) << ")";
+    }
+    return doc;
+  }
+
+  Doc GetUniqueName(std::string prefix) {
+    // std::replace(prefix.begin(), prefix.end(), '.', '_');
+    std::string unique_prefix = prefix;
+    auto it = name_alloc_map_.find(prefix);
+    if (it != name_alloc_map_.end()) {
+      while (true) {
+        std::ostringstream os;
+        os << prefix << "_" << (++it->second);
+        std::string name = os.str();
+        if (name_alloc_map_.count(name) == 0) {
+          unique_prefix = name;
+          break;
+        }
+      }
+    }
+    name_alloc_map_[unique_prefix] = 0;
+    return Doc::Text(unique_prefix);
+  }
+
+  Doc AllocVar(const Var& var) {
+    const auto& it = memo_var_.find(var);
+    if (it != memo_var_.end()) {
+      return it->second;
+    }
+    std::string name = var->name_hint;
+    if (name.length() == 0 || !std::isalpha(name[0])) {
+      name = "v" + name;
+    }
+    Doc val = GetUniqueName(name);
+    memo_var_[var] = val;
+    return val << ": " << Print(GetType(var));
+  }
+
+  Doc AllocBuf(const Buffer& buffer) {
+    const auto& it = memo_buf_.find(buffer);
+    if (it != memo_buf_.end()) {
+      return it->second;
+    }
+    std::string name = buffer->name;
+    if (name.length() == 0 || !std::isalpha(name[0])) {
+      name = "buf_" + name;
+    }
+    Doc val = GetUniqueName(name);
+    memo_buf_[buffer] = val;
+    return val;
+  }
+
+  /*!
+   * \brief special method to render vectors of docs with a separator
+   * \param vec vector of docs
+   * \param sep separator
+   */
+  static Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
+    Doc seq;
+    if (vec.size() != 0) {
+      seq = vec[0];
+      for (size_t i = 1; i < vec.size(); i++) {
+        seq << sep << vec[i];
+      }
+    }
+    return seq;
+  }
+
+  /*!
+   * \brief dump meta info
+   * \return Doc with meta info
+   */
+  Doc DumpMeta() {
+    if (show_meta_) {
+      return Doc::Text("__tvm_meta__ = ")
+          << (meta_.empty() ? Doc::Text("None") : meta_.GetMetaSection());
+    } else {
+      return Doc::Text("");
+    }
+  }
+
+  Doc PrintBody(const Stmt& body, bool indent = true) {
+    Doc doc;
+    if (body->IsInstance<SeqStmtNode>()) return Print(body);
+    doc << " {" << Doc::Indent(2, Doc::NewLine() << Print(body)) << Doc::NewLine() << "}";
+    return doc;
+  }
+};
+
+Doc TIRTextPrinter::Print(const ObjectRef& node) {
+  if (!node.defined()) return Doc::Text("(nullptr)");
+  if (node.as<StmtNode>()) {
+    return VisitStmt(Downcast<Stmt>(node));
+  } else if (node.as<PrimExprNode>()) {
+    return VisitExpr(Downcast<PrimExpr>(node));
+  } else if (node.as<TypeNode>()) {
+    return VisitType(Downcast<Type>(node));
+  } else if (node.as<PrimFuncNode>()) {
+    return PrintPrimFunc(Downcast<PrimFunc>(node));
+  } else if (node.as<IRModuleNode>()) {
+    return PrintIRModule(Downcast<IRModule>(node));
+  } else if (node.as<ArrayNode>()) {
+    return PrintArray(node.as<ArrayNode>());
+  } else if (node.as<IterVarNode>()) {
+    return PrintIterVar(node.as<IterVarNode>());
+  } else if (node.as<RangeNode>()) {
+    return PrintRange(node.as<RangeNode>());
+  } else if (node.as<BufferNode>()) {
+    return PrintBuffer(node.as<BufferNode>());
+  } else if (node.as<StringObj>()) {
+    return PrintString(node.as<StringObj>());
+  } else {
+    return this->meta_.GetMetaNode(node);
+  }
+}
+
+Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
+  const auto* op = primFunc.operator->();
+  const auto& signature = op->func_type_annotation();
+  // collect Meta in DictAttr
+  for (const auto& it : primFunc->attrs->dict) {
+    meta_collector_.Collect(it.second);
+  }
+  // collect buffers in buffer_map
+  memo_var_.clear();
+  memo_buf_.clear();
+  for (const auto& it : op->buffer_map) {
+    memo_buf_[it.second] = AllocBuf(it.second);
+  }
+  // print PrimFunc
+  Doc doc;
+  doc << "primfn" << "(";
+  // print params and its type annotation
+  std::vector<Doc> params;
+  for (const auto& param : op->params) {
+    params.push_back(Print(param));
+  }
+  Doc sep;
+  doc << PrintSep(params, Doc::Indent(9, Doc::Text(", "))) << ")";
+  // print return type
+  doc << " -> " << Print(signature->ret_type);
+  // print attr
+  Doc attr_doc;
+  std::vector<Doc> attr_docs;
+  for (const auto& it : op->attrs->dict) {
+    attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
+  }
+  attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}";
+  doc << Doc::Indent(2, attr_doc);
+  // print all the buffers in the tree
+  Doc buffer_doc;
+  std::vector<Doc> buffer_docs;
+  for (const auto& it : memo_buf_) {
+    const auto& buf = it.first;
+    buffer_docs.push_back(Print(buf)
+                          << Doc::Text(": Buffer(") << Print(buf->data) << ", "
+                          << PrintDType(buf->dtype) << ", " << Print(buf->shape) << ", "
+                          << Print(buf->strides));
+    if (!is_zero(buf->elem_offset)) {
+      buffer_docs.back() << ", elem_offset=" << Print(buf->elem_offset);
+    }
+    if (buf->scope != "global") {
+      buffer_docs.back() << ", scope=" << Doc::StrLiteral(buf->scope);
+    }
+    if (buf->data_alignment != 128) {
+      buffer_docs.back() << ", align=" << buf->data_alignment;
+    }
+    if (buf->offset_factor != 1) {
+      buffer_docs.back() << ", offset_factor=" << buf->offset_factor;
+    }
+    if (buf->buffer_type != 1) {
+      buffer_docs.back() << ", type=" << Doc::StrLiteral("auto");
+    }
+    buffer_docs.back() << ")";
+  }
+  buffer_doc << Doc::NewLine() << "buffers = {";
+  buffer_doc << PrintSep(buffer_docs, Doc::Indent(9, Doc::Text(",") << Doc::NewLine()));
+  doc << Doc::Indent(2, buffer_doc) << "}";
+  // print buffer_map
+  std::vector<Doc> buffer_map_doc;
+  for (const auto& it : op->buffer_map) {
+    buffer_map_doc.push_back(Print(it.first) << ": " << Print(it.second));
+  }
+  doc << Doc::Indent(2, Doc::NewLine()
+      << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
+  doc << PrintBody(op->body);
+  return doc;
+}
+
+Doc TIRTextPrinter::PrintIRModule(const IRModule& module) {
+  const auto* op = module.operator->();
+  Doc doc;
+
+  Doc body;
+  body << Doc::NewLine();
+  std::vector<Doc> functions;
+  for (auto it = op->functions.begin(); it != op->functions.end(); ++it) {
+    if ((*it).second.as<PrimFuncNode>()) {
+      functions.push_back(Print((*it).second));
+    }
+  }
+  body << TIRTextPrinter::PrintSep(functions, Doc::NewLine() << Doc::NewLine());
+  body << Doc::NewLine() << DumpMeta();
+  doc << Doc::Indent(0, body);
+  return doc;
+}
+
+Doc TIRTextPrinter::PrintArray(const ArrayNode* op) {
+  Doc doc;
+  doc << '[';
+  for (size_t i = 0; i < op->data.size(); ++i) {
+    if (i != 0) {
+      doc << ", ";
+    }
+    doc << Print(op->data[i]);
+  }
+  doc << ']';
+  return doc;
+}
+
+Doc TIRTextPrinter::PrintIterVar(const IterVarNode* op) {
+  Doc doc;
+  doc << "IterVar(" << Print(op->var) << ", [" << Print(op->dom) << "], "
+      << Doc::StrLiteral(IterVarType2String(op->iter_type)) << ", "
+      << Doc::StrLiteral(op->thread_tag) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::PrintRange(const RangeNode* op) {
+  return Print(op->min) << ":" << Print(op->min + op->extent);
+}
+
+Doc TIRTextPrinter::PrintBuffer(const BufferNode* op) {
+  const Buffer& buffer = GetRef<Buffer>(op);
+  CHECK_GT(memo_buf_.count(buffer), 0);
+  return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : memo_buf_[buffer];
+}
+
+Doc TIRTextPrinter::VisitExprDefault_(const Object* op) {
+  return this->meta_.GetMetaNode(GetRef<ObjectRef>(op));
+}
+
+Doc TIRTextPrinter::VisitStmtDefault_(const Object* op) {
+  return this->meta_.GetMetaNode(GetRef<ObjectRef>(op));
+}
+
+Doc TIRTextPrinter::VisitExpr_(const IntImmNode* op) {
+  return PrintConstScalar<int64_t>(op->dtype, &(op->value));
+}
+
+Doc TIRTextPrinter::VisitExpr_(const FloatImmNode* op) {
+  return PrintConstScalar<double>(op->dtype, &(op->value));
+}
+
+Doc TIRTextPrinter::VisitExpr_(const StringImmNode* op) { return Doc::StrLiteral(op->value); }
+
+Doc TIRTextPrinter::VisitExpr_(const CastNode* op) {
+  Doc doc;
+  doc << "cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const VarNode* op) {
+  const Var& var = GetRef<Var>(op);
+  return meta_.InMeta(var) ? meta_.GetMetaNode(var) : AllocVar(GetRef<Var>(op));
+}
+
+#define TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OpName, OpString)     \
+  Doc TIRTextPrinter::VisitExpr_(const OpName* op) {               \
+    Doc doc;                                                       \
+    doc << '(' << Print(op->a) << OpString << Print(op->b) << ")"; \
+    return doc;                                                    \
+  }
+
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AddNode, " + ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(SubNode, " - ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(MulNode, "*")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(DivNode, " / ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(ModNode, " % ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(EQNode, " == ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(NENode, " != ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LTNode, " < ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LENode, " <= ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GTNode, " > ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GENode, " >= ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AndNode, " and ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OrNode, " or ")
+
+Doc TIRTextPrinter::VisitExpr_(const FloorDivNode* op) {
+  Doc doc;
+  doc << "floordiv(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const FloorModNode* op) {
+  Doc doc;
+  doc << "floormod(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const MinNode* op) {
+  Doc doc;
+  doc << "min(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const MaxNode* op) {
+  Doc doc;
+  doc << "max(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const NotNode* op) {
+  Doc doc;
+  doc << "!" << Print(op->a);
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const SelectNode* op) {
+  Doc doc;
+  doc << "select(" << Print(op->condition) << ", " << Print(op->true_value) << ", "
+      << Print(op->false_value) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const BufferLoadNode* op) {
+  Doc doc;
+  doc << Print(op->buffer) << Print(op->indices);
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const LoadNode* op) {
+  Doc doc;
+  doc << "load(" << PrintDType(op->dtype) << ", "
+      << Print(op->buffer_var) << "[" << Print(op->index) << "])";
+  if (!is_one(op->predicate)) {
+    doc << " if " << Print(op->predicate);
+  }
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const RampNode* op) {
+  Doc doc;
+  doc << "ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " << op->lanes << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const BroadcastNode* op) {
+  Doc doc;
+  doc << "broadcast(" << Print(op->value) << ", " << op->lanes << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const LetNode* op) {
+  Doc doc;
+  doc << "let " << Print(op->var) << " = " << Print(op->value) << " in " << Print(op->body);
+  return doc;
+}
+
+inline const char* CallType2String(CallNode::CallType t) {
+  switch (t) {
+    case CallNode::Extern:return "extern";
+    case CallNode::ExternCPlusPlus:return "extern_cpp";
+    case CallNode::PureExtern:return "pure_extern";
+    case CallNode::Halide:return "halide";
+    case CallNode::Intrinsic:return "intrin";
+    case CallNode::PureIntrinsic:return "pure_intrin";
+  }
+  return "Unknown";

Review comment:
       Thanks! You may add “throw;” to avoid using return :-)
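
One way to read this suggestion, as a minimal sketch: let the `switch` cover every enumerator and make the path after it raise an error instead of returning a placeholder string. The `CallKind` enum and `CallKindToString` function below are hypothetical stand-ins for illustration, not TVM's `CallNode::CallType` or the function in this diff. Note that a bare `throw;` with no exception currently being handled calls `std::terminate`, so the sketch throws a concrete exception type instead.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-in for CallNode::CallType, for illustration only.
enum class CallKind { kExtern, kPureExtern, kIntrinsic };

// Maps each enumerator to its text form. A value outside the enumerators is
// reported as an error rather than printed as the string "Unknown".
inline const char* CallKindToString(CallKind k) {
  switch (k) {
    case CallKind::kExtern: return "extern";
    case CallKind::kPureExtern: return "pure_extern";
    case CallKind::kIntrinsic: return "intrin";
  }
  // Only reached when k holds a value that is not a named enumerator.
  throw std::invalid_argument("unknown CallKind");
}
```

Leaving out a `default:` label keeps compiler warnings about unhandled enumerators useful when a new enumerator is added later.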







[GitHub] [incubator-tvm] spectrometerHBH edited a comment on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH edited a comment on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-622203710


   > Some comments:
   > 
   > 1. There is more than one space before the left brace in the allocation line
   >    ```
   >    allocate(B.local, float32, [64])  {
   >    ```
   > 2. Can we use the same rule for the allocation stmt as the one for attr? The allocation stmt currently introduces extra indentation
   > 3. ```
   >     attr [IterVar(blockIdx.z: int32, [(nullptr)], "ThreadIndex", "blockIdx.z")] "thread_extent" = 196;
   >    ```
   >    
   >    
   >    It is strange to print `nullptr` here especially in square brackets. Perhaps we can use `IterVar(blockIdx.z: int32, (nullptr), "ThreadIndex", "blockIdx.z")]` or `IterVar(blockIdx.z: int32, , "ThreadIndex", "blockIdx.z")]`
   > 4. Considering future parsing use, we must print the dtype for every const number. But we may use some shorthand for common dtypes, e.g. `2f` for float32, `2h` for float16 (half), and a plain `2` for int32 (since most integer numbers in a schedule are int32). We should still keep the complete form for every type, e.g. `int8(2)`, `float64(2)` (or maybe `fp64(2)`); `float32(2)` is legal as well.
   
   1. fixed
   2. fixed. But here we implicitly assume that `Allocate` and `Attr` will have at least one child. Otherwise, for such a scenario
   ```c++
   attr...;
   attr...;
   attr...;
   for...
   ```
   We cannot determine whether it is `attr|attr|attr|for` or `attr(attr)|attr|for` or `attr(attr)|attr(for)` or `attr(attr(attr))|for` or `attr(attr(attr(for)))`
   
   3. fixed





[GitHub] [incubator-tvm] spectrometerHBH commented on a change in pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r418406985



##########
File path: src/printer/tir_text_printer.cc
##########
@@ -0,0 +1,735 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file printer/tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ *  \brief Meta node collector
+ *  If we decide to put some node into meta, then all the sub-nodes inside
+ *  it need to be put in meta as well, since when parsing we need to know
+ *  whether two refs are the same
+ */
+class MetaCollector : public StmtExprVisitor {
+ public:
+  explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
+
+  void Collect(const ObjectRef& n) {
+    // these nodes can be printed directly (as a string literal or by their identifier)
+    if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>()
+        || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+      return;
+    }
+    if (n->IsInstance<StmtNode>()) {
+      VisitStmt(Downcast<Stmt>(n));
+    } else if (n->IsInstance<PrimExprNode>()) {
+      VisitExpr(Downcast<PrimExpr>(n));
+    }
+  }
+
+  void VisitStmt(const Stmt& n) override {
+    meta_->GetMetaNode(n);
+    StmtVisitor::VisitStmt(n);
+  }
+
+  void VisitExpr(const PrimExpr& n) override {
+    meta_->GetMetaNode(n);
+    ExprVisitor::VisitExpr(n);
+  }
+
+ private:
+  TextMetaDataContext* meta_;
+};
+
+class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
+                       public ExprFunctor<Doc(const PrimExpr&)>,
+                       public TypeFunctor<Doc(const Type&)> {
+ public:
+  explicit TIRTextPrinter(bool show_meta) : show_meta_(show_meta), meta_collector_(&meta_) {}
+
+  /*! \brief Print the node */
+  Doc Print(const ObjectRef& node);
+
+ private:
+  /*! \brief whether show meta data */
+  bool show_meta_;
+  /*! \brief meta data context */
+  TextMetaDataContext meta_;
+  /*! \brief meta collector */
+  MetaCollector meta_collector_;
+  /*! \brief Map from Var to Doc */
+  std::unordered_map<Var, Doc, ObjectHash, ObjectEqual> memo_var_;
+  /*! \brief Map from Buffer to Doc */
+  std::unordered_map<Buffer, Doc, ObjectHash, ObjectEqual> memo_buf_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+
+  Doc VisitExpr_(const IntImmNode* op) override;
+  Doc VisitExpr_(const FloatImmNode* op) override;
+  Doc VisitExpr_(const StringImmNode* op) override;
+  Doc VisitExpr_(const CastNode* op) override;
+  Doc VisitExpr_(const VarNode* op) override;
+  Doc VisitExpr_(const AddNode* op) override;
+  Doc VisitExpr_(const SubNode* op) override;
+  Doc VisitExpr_(const MulNode* op) override;
+  Doc VisitExpr_(const DivNode* op) override;
+  Doc VisitExpr_(const ModNode* op) override;
+  Doc VisitExpr_(const FloorDivNode* op) override;
+  Doc VisitExpr_(const FloorModNode* op) override;
+  Doc VisitExpr_(const MinNode* op) override;
+  Doc VisitExpr_(const MaxNode* op) override;
+  Doc VisitExpr_(const EQNode* op) override;
+  Doc VisitExpr_(const NENode* op) override;
+  Doc VisitExpr_(const LTNode* op) override;
+  Doc VisitExpr_(const LENode* op) override;
+  Doc VisitExpr_(const GTNode* op) override;
+  Doc VisitExpr_(const GENode* op) override;
+  Doc VisitExpr_(const AndNode* op) override;
+  Doc VisitExpr_(const OrNode* op) override;
+  Doc VisitExpr_(const NotNode* op) override;
+  Doc VisitExpr_(const SelectNode* op) override;
+  Doc VisitExpr_(const BufferLoadNode* op) override;
+  Doc VisitExpr_(const LoadNode* op) override;
+  Doc VisitExpr_(const RampNode* op) override;
+  Doc VisitExpr_(const BroadcastNode* op) override;
+  Doc VisitExpr_(const LetNode* op) override;
+  Doc VisitExpr_(const CallNode* op) override;
+  Doc VisitExpr_(const ShuffleNode* op) override;
+  Doc VisitExpr_(const ReduceNode* op) override;
+  Doc VisitExprDefault_(const Object* op) override;
+
+  Doc VisitStmt_(const LetStmtNode* op) override;
+  Doc VisitStmt_(const AttrStmtNode* op) override;
+  Doc VisitStmt_(const AssertStmtNode* op) override;
+  Doc VisitStmt_(const StoreNode* op) override;
+  Doc VisitStmt_(const BufferStoreNode* op) override;
+  Doc VisitStmt_(const BufferRealizeNode* op) override;
+  Doc VisitStmt_(const AllocateNode* op) override;
+  Doc VisitStmt_(const FreeNode* op) override;
+  Doc VisitStmt_(const IfThenElseNode* op) override;
+  Doc VisitStmt_(const SeqStmtNode* op) override;
+  Doc VisitStmt_(const EvaluateNode* op) override;
+  Doc VisitStmt_(const ForNode* op) override;
+  Doc VisitStmt_(const PrefetchNode* op) override;
+  Doc VisitStmtDefault_(const Object* op) override;
+
+  Doc VisitType_(const PrimTypeNode* node) override;
+  Doc VisitType_(const PointerTypeNode* node) override;
+  Doc VisitType_(const TupleTypeNode* node) override;
+
+  Doc PrintIRModule(const IRModule& module);
+  Doc PrintPrimFunc(const PrimFunc& primFunc);
+  Doc PrintArray(const ArrayNode* op);
+  Doc PrintIterVar(const IterVarNode* op);
+  Doc PrintRange(const RangeNode* op);
+  Doc PrintBuffer(const BufferNode* op);
+  Doc PrintString(const StringObj* op) {
+    return Doc::StrLiteral(op->data);
+  }
+
+  /*!
+   * \brief special method to print out data type
+   * \param dtype The data type
+   */
+  static Doc PrintDType(DataType dtype) {
+    return Doc::Text(runtime::DLDataType2String(dtype));

Review comment:
       When will we encounter a dtype with bits=0?







[GitHub] [incubator-tvm] roastduck commented on a change in pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
roastduck commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r426130269



##########
File path: src/printer/tir_text_printer.cc
##########
@@ -0,0 +1,597 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+#include "text_printer.h"
+
+namespace tvm {
+namespace tir {
+
+Doc TIRTextPrinter::Print(const ObjectRef& node) {
+  if (!node.defined()) return Doc::Text("(nullptr)");
+  if (node->IsInstance<StmtNode>()) {
+    return VisitStmt(Downcast<Stmt>(node));
+  } else if (node->IsInstance<AnyNode>()) {
+    return Doc::Text("?");
+  } else if (node->IsInstance<PrimExprNode>()) {
+    return VisitExpr(Downcast<PrimExpr>(node));
+  } else if (node->IsInstance<TypeNode>()) {
+    return VisitType(Downcast<Type>(node));
+  } else if (node->IsInstance<PrimFuncNode>()) {
+    return PrintPrimFunc(Downcast<PrimFunc>(node));
+  } else if (node->IsInstance<IRModuleNode>()) {
+    return PrintIRModule(Downcast<IRModule>(node));
+  } else if (node->IsInstance<ArrayNode>()) {
+    return PrintArray(node.as<ArrayNode>());
+  } else if (node->IsInstance<IterVarNode>()) {
+    return PrintIterVar(node.as<IterVarNode>());
+  } else if (node->IsInstance<RangeNode>()) {
+    return PrintRange(node.as<RangeNode>());
+  } else if (node->IsInstance<BufferNode>()) {
+    return PrintBuffer(node.as<BufferNode>());
+  } else if (node->IsInstance<StringObj>()) {
+    return PrintString(node.as<StringObj>());
+  } else {
+    return this->meta_->GetMetaNode(node);
+  }
+}
+
+Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
+  const auto* op = primFunc.operator->();
+  const auto& signature = op->func_type_annotation();
+  // collect Meta in DictAttr
+  for (const auto& it : primFunc->attrs->dict) {
+    meta_collector_.Collect(it.second);
+  }
+  // collect buffers in buffer_map
+  memo_var_.clear();
+  memo_buf_.clear();
+  for (const auto& it : op->buffer_map) {
+    memo_buf_[it.second] = AllocBuf(it.second);
+  }
+  // print PrimFunc
+  Doc doc;
+  doc << "primfn" << "(";
+  // print params and its type annotation
+  std::vector<Doc> params;
+  for (const auto& param : op->params) {
+    params.push_back(Print(param));
+  }
+  Doc sep;
+  doc << PrintSep(params, Doc::Indent(9, Doc::Text(", "))) << ")";
+  // print return type
+  doc << " -> " << Print(signature->ret_type);
+  // print attr
+  Doc attr_doc;
+  std::vector<Doc> attr_docs;
+  for (const auto& it : op->attrs->dict) {
+    attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
+  }
+  attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}";
+  doc << Doc::Indent(2, attr_doc);
+  // print all the buffers in the tree
+  Doc buffer_doc;
+  std::vector<Doc> buffer_docs;
+  for (const auto& it : memo_buf_) {
+    const auto& buf = it.first;
+    buffer_docs.push_back(Print(buf)
+                              << Doc::Text(": Buffer(") << Print(buf->data) << ", "
+                              << PrintDType(buf->dtype) << ", " << Print(buf->shape) << ", "
+                              << Print(buf->strides));
+    if (!is_zero(buf->elem_offset)) {
+      buffer_docs.back() << ", elem_offset=" << Print(buf->elem_offset);
+    }
+    if (buf->scope != "global") {
+      buffer_docs.back() << ", scope=" << Doc::StrLiteral(buf->scope);
+    }
+    if (buf->data_alignment != 128) {
+      buffer_docs.back() << ", align=" << buf->data_alignment;
+    }
+    if (buf->offset_factor != 1) {
+      buffer_docs.back() << ", offset_factor=" << buf->offset_factor;
+    }
+    if (buf->buffer_type != 1) {
+      buffer_docs.back() << ", type=" << Doc::StrLiteral("auto");
+    }
+    buffer_docs.back() << ")";
+  }
+  buffer_doc << Doc::NewLine() << "buffers = {";
+  buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << Doc::NewLine()));
+  doc << Doc::Indent(2, buffer_doc) << "}";
+  // print buffer_map
+  std::vector<Doc> buffer_map_doc;
+  for (const auto& it : op->buffer_map) {
+    buffer_map_doc.push_back(Print(it.first) << ": " << Print(it.second));
+  }
+  doc << Doc::Indent(2, Doc::NewLine()
+      << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
+  doc << PrintBody(op->body);
+  return doc;
+}
+
+Doc TIRTextPrinter::PrintIRModule(const IRModule& module) {
+  const auto* op = module.operator->();
+  Doc doc;
+
+  Doc body;
+  body << Doc::NewLine();
+  std::vector<Doc> functions;
+  for (auto it = op->functions.begin(); it != op->functions.end(); ++it) {
+    if ((*it).second.as<PrimFuncNode>()) {
+      functions.push_back(Print((*it).second));
+    }
+  }
+  body << TIRTextPrinter::PrintSep(functions, Doc::NewLine() << Doc::NewLine());
+  doc << Doc::Indent(0, body);
+  return doc;
+}
+
+Doc TIRTextPrinter::PrintArray(const ArrayNode* op) {
+  Doc doc;
+  doc << '[';
+  for (size_t i = 0; i < op->data.size(); ++i) {
+    if (i != 0) {
+      doc << ", ";
+    }
+    doc << Print(op->data[i]);
+  }
+  doc << ']';
+  return doc;
+}
+
+Doc TIRTextPrinter::PrintIterVar(const IterVarNode* op) {
+  Doc doc;
+  doc << "IterVar(" << Print(op->var);
+  if (op->dom.defined()) {
+    doc << ", [" << Print(op->dom) << "], ";
+  } else {
+    doc << ", " << Print(op->dom) << ", ";
+  }
+  doc << Doc::StrLiteral(IterVarType2String(op->iter_type)) << ", ";
+  doc << Doc::StrLiteral(op->thread_tag) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::PrintRange(const RangeNode* op) {
+  return Print(op->min) << ":" << Print(op->min + op->extent);
+}
+
+Doc TIRTextPrinter::PrintBuffer(const BufferNode* op) {
+  const Buffer& buffer = GetRef<Buffer>(op);
+  CHECK_GT(memo_buf_.count(buffer), 0);
+  return meta_->InMeta(buffer) ? meta_->GetMetaNode(buffer) : memo_buf_[buffer];
+}
+
+Doc TIRTextPrinter::VisitExprDefault_(const Object* op) {
+  return this->meta_->GetMetaNode(GetRef<ObjectRef>(op));
+}
+
+Doc TIRTextPrinter::VisitStmtDefault_(const Object* op) {
+  return this->meta_->GetMetaNode(GetRef<ObjectRef>(op));
+}
+
+Doc TIRTextPrinter::VisitExpr_(const IntImmNode* op) {
+  return PrintConstScalar<int64_t>(op->dtype, op->value);
+}
+
+Doc TIRTextPrinter::VisitExpr_(const FloatImmNode* op) {
+  return PrintConstScalar<double>(op->dtype, op->value);
+}
+
+Doc TIRTextPrinter::VisitExpr_(const StringImmNode* op) { return Doc::StrLiteral(op->value); }
+
+Doc TIRTextPrinter::VisitExpr_(const CastNode* op) {
+  Doc doc;
+  doc << "cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const VarNode* op) {
+  const Var& var = GetRef<Var>(op);
+  return meta_->InMeta(var) ? meta_->GetMetaNode(var) : AllocVar(GetRef<Var>(op));
+}
+
+#define TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OpName, OpString)     \
+  Doc TIRTextPrinter::VisitExpr_(const OpName* op) {               \
+    Doc doc;                                                       \
+    doc << "(" << Print(op->a) << OpString;                        \
+    doc << Print(op->b) << ")";                                    \
+    return doc;                                                    \
+  }
+
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AddNode, " + ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(SubNode, " - ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(MulNode, "*")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(DivNode, " / ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(ModNode, " % ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(EQNode, " == ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(NENode, " != ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LTNode, " < ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LENode, " <= ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GTNode, " > ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GENode, " >= ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AndNode, " && ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OrNode, " || ")
+
+Doc TIRTextPrinter::VisitExpr_(const FloorDivNode* op) {
+  Doc doc;
+  doc << "floordiv(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const FloorModNode* op) {
+  Doc doc;
+  doc << "floormod(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const MinNode* op) {
+  Doc doc;
+  doc << "min(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const MaxNode* op) {
+  Doc doc;
+  doc << "max(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const NotNode* op) {
+  Doc doc;
+  doc << "!" << Print(op->a);
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const SelectNode* op) {
+  Doc doc;
+  doc << "select(" << Print(op->condition) << ", " << Print(op->true_value) << ", "
+      << Print(op->false_value);
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const BufferLoadNode* op) {
+  Doc doc;
+  doc << Print(op->buffer) << Print(op->indices);
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const LoadNode* op) {
+  Doc doc;
+  doc << "(" << PrintDType(op->dtype) << "*)"
+      << Print(op->buffer_var) << "[" << Print(op->index) << "])";

Review comment:
       Parentheses are not matched here. Please consider `((dtype*)buffer_var)[index]`.
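
The suggested form can be sketched in isolation. This is a minimal standalone illustration, not the TVM printer itself: `FormatLoad` and `ParensBalanced` are hypothetical helpers that use plain `std::string` in place of `Doc`, assuming the dtype, buffer variable, and index have already been rendered to text.

```cpp
#include <cassert>
#include <string>

// Hypothetical checker (not part of TVM): returns true when every '(' in
// `text` is closed by a later ')' and no ')' appears without an opener.
bool ParensBalanced(const std::string& text) {
  int depth = 0;
  for (char c : text) {
    if (c == '(') ++depth;
    if (c == ')') --depth;
    if (depth < 0) return false;  // a ')' appeared before its '('
  }
  return depth == 0;
}

// Hypothetical helper (not part of TVM): renders a load in the form the
// review suggests, ((dtype*)buffer_var)[index].
std::string FormatLoad(const std::string& dtype, const std::string& buffer_var,
                       const std::string& index) {
  std::string text = "((" + dtype + "*)" + buffer_var + ")[" + index + "]";
  // Debug-build sanity check; assumes the inputs are themselves balanced.
  assert(ParensBalanced(text));
  return text;
}
```

For comparison, the string pattern in the quoted diff would come out as `(float32*)A[i])` for a `float32` load from `A` at `i`, which the checker above rejects because of the trailing `)`.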




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] junrushao1994 commented on a change in pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r418370164



##########
File path: src/printer/tir_text_printer.cc
##########
@@ -0,0 +1,735 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file printer/tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ *  \brief Meta node collector
+ *  If we decide to put some node into meta, then all the sub-nodes inside
+ *  it need to be put in meta as well, since when parsing we need to know
+ *  whether two refs are the same
+ */
+class MetaCollector : public StmtExprVisitor {
+ public:
+  explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
+
+  void Collect(const ObjectRef& n) {
+    // these nodes can be print directly(StringLiteral or use identifier to identify)
+    if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>()
+        || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+      return;
+    }
+    if (n->IsInstance<StmtNode>()) {
+      VisitStmt(Downcast<Stmt>(n));
+    } else if (n->IsInstance<PrimExprNode>()) {
+      VisitExpr(Downcast<PrimExpr>(n));
+    }
+  }
+
+  void VisitStmt(const Stmt& n) override {
+    meta_->GetMetaNode(n);
+    StmtVisitor::VisitStmt(n);
+  }
+
+  void VisitExpr(const PrimExpr& n) override {
+    meta_->GetMetaNode(n);
+    ExprVisitor::VisitExpr(n);
+  }
+
+ private:
+  TextMetaDataContext* meta_;
+};
+
+class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
+                       public ExprFunctor<Doc(const PrimExpr&)>,
+                       public TypeFunctor<Doc(const Type&)> {
+ public:
+  explicit TIRTextPrinter(bool show_meta) : show_meta_(show_meta), meta_collector_(&meta_) {}
+
+  /*! \brief Print the node */
+  Doc Print(const ObjectRef& node);
+
+ private:
+  /*! \brief whether show meta data */
+  bool show_meta_;
+  /*! \brief meta data context */
+  TextMetaDataContext meta_;
+  /*! \brief meta collector */
+  MetaCollector meta_collector_;
+  /*! \brief Map from Var to Doc */
+  std::unordered_map<Var, Doc, ObjectHash, ObjectEqual> memo_var_;
+  /*! \brief Map from Buffer to Doc */
+  std::unordered_map<Buffer, Doc, ObjectHash, ObjectEqual> memo_buf_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+
+  Doc VisitExpr_(const IntImmNode* op) override;
+  Doc VisitExpr_(const FloatImmNode* op) override;
+  Doc VisitExpr_(const StringImmNode* op) override;
+  Doc VisitExpr_(const CastNode* op) override;
+  Doc VisitExpr_(const VarNode* op) override;
+  Doc VisitExpr_(const AddNode* op) override;
+  Doc VisitExpr_(const SubNode* op) override;
+  Doc VisitExpr_(const MulNode* op) override;
+  Doc VisitExpr_(const DivNode* op) override;
+  Doc VisitExpr_(const ModNode* op) override;
+  Doc VisitExpr_(const FloorDivNode* op) override;
+  Doc VisitExpr_(const FloorModNode* op) override;
+  Doc VisitExpr_(const MinNode* op) override;
+  Doc VisitExpr_(const MaxNode* op) override;
+  Doc VisitExpr_(const EQNode* op) override;
+  Doc VisitExpr_(const NENode* op) override;
+  Doc VisitExpr_(const LTNode* op) override;
+  Doc VisitExpr_(const LENode* op) override;
+  Doc VisitExpr_(const GTNode* op) override;
+  Doc VisitExpr_(const GENode* op) override;
+  Doc VisitExpr_(const AndNode* op) override;
+  Doc VisitExpr_(const OrNode* op) override;
+  Doc VisitExpr_(const NotNode* op) override;
+  Doc VisitExpr_(const SelectNode* op) override;
+  Doc VisitExpr_(const BufferLoadNode* op) override;
+  Doc VisitExpr_(const LoadNode* op) override;
+  Doc VisitExpr_(const RampNode* op) override;
+  Doc VisitExpr_(const BroadcastNode* op) override;
+  Doc VisitExpr_(const LetNode* op) override;
+  Doc VisitExpr_(const CallNode* op) override;
+  Doc VisitExpr_(const ShuffleNode* op) override;
+  Doc VisitExpr_(const ReduceNode* op) override;
+  Doc VisitExprDefault_(const Object* op) override;
+
+  Doc VisitStmt_(const LetStmtNode* op) override;
+  Doc VisitStmt_(const AttrStmtNode* op) override;
+  Doc VisitStmt_(const AssertStmtNode* op) override;
+  Doc VisitStmt_(const StoreNode* op) override;
+  Doc VisitStmt_(const BufferStoreNode* op) override;
+  Doc VisitStmt_(const BufferRealizeNode* op) override;
+  Doc VisitStmt_(const AllocateNode* op) override;
+  Doc VisitStmt_(const FreeNode* op) override;
+  Doc VisitStmt_(const IfThenElseNode* op) override;
+  Doc VisitStmt_(const SeqStmtNode* op) override;
+  Doc VisitStmt_(const EvaluateNode* op) override;
+  Doc VisitStmt_(const ForNode* op) override;
+  Doc VisitStmt_(const PrefetchNode* op) override;
+  Doc VisitStmtDefault_(const Object* op) override;
+
+  Doc VisitType_(const PrimTypeNode* node) override;
+  Doc VisitType_(const PointerTypeNode* node) override;
+  Doc VisitType_(const TupleTypeNode* node) override;
+
+  Doc PrintIRModule(const IRModule& module);
+  Doc PrintPrimFunc(const PrimFunc& primFunc);
+  Doc PrintArray(const ArrayNode* op);
+  Doc PrintIterVar(const IterVarNode* op);
+  Doc PrintRange(const RangeNode* op);
+  Doc PrintBuffer(const BufferNode* op);
+  Doc PrintString(const StringObj* op) {
+    return Doc::StrLiteral(op->data);
+  }
+
+  /*!
+   * \brief special method to print out data type
+   * \param dtype The data type
+   */
+  static Doc PrintDType(DataType dtype) {
+    return Doc::Text(runtime::DLDataType2String(dtype));
+  }
+
+  /*!
+   * \brief special method to print out const scalar
+   * \param dtype The data type
+   * \param data The pointer to hold the data.
+   */
+  template <typename T>
+  static Doc PrintConstScalar(DataType dtype, const T* data) {
+    Doc doc;
+    std::ostringstream os;
+    os << data[0];
+    if (dtype == DataType::Int(32)) {
+      doc << Doc::Text(os.str());
+    } else {
+      doc << PrintDType(dtype) << "(" << Doc::Text(os.str()) << ")";
+    }
+    return doc;
+  }
+
+  Doc GetUniqueName(std::string prefix) {
+    // std::replace(prefix.begin(), prefix.end(), '.', '_');
+    std::string unique_prefix = prefix;
+    auto it = name_alloc_map_.find(prefix);
+    if (it != name_alloc_map_.end()) {
+      while (true) {
+        std::ostringstream os;
+        os << prefix << "_" << (++it->second);
+        std::string name = os.str();
+        if (name_alloc_map_.count(name) == 0) {
+          unique_prefix = name;
+          break;
+        }
+      }
+    }
+    name_alloc_map_[unique_prefix] = 0;
+    return Doc::Text(unique_prefix);
+  }
+
+  Doc AllocVar(const Var& var) {
+    const auto& it = memo_var_.find(var);
+    if (it != memo_var_.end()) {
+      return it->second;
+    }
+    std::string name = var->name_hint;
+    if (name.length() == 0 || !std::isalpha(name[0])) {
+      name = "v" + name;
+    }
+    Doc val = GetUniqueName(name);

Review comment:
       Just to clarify: the first time `name.length() == 0`, what does `GetUniqueName(name)` return, `v` or `v0`?
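
Tracing the quoted code suggests an answer: with an empty `name_hint`, `AllocVar` first rewrites the name to `v`, and `GetUniqueName` returns a prefix unchanged the first time it sees it, so the first result would be `v`, with `v_1`, `v_2`, and so on for later collisions. The sketch below replays that logic with `std::string` in place of `Doc`. `SanitizeVarName` and `NameAllocator` are hypothetical names introduced only for illustration, and the `assert` precondition is an addition of the sketch, not something in the PR.

```cpp
#include <cassert>
#include <cctype>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the name handling in AllocVar: hints that are
// empty or do not start with a letter get a "v" prefix.
std::string SanitizeVarName(const std::string& name_hint) {
  if (name_hint.empty() ||
      !std::isalpha(static_cast<unsigned char>(name_hint[0]))) {
    return "v" + name_hint;
  }
  return name_hint;
}

// Hypothetical stand-in for GetUniqueName, returning std::string, not Doc.
class NameAllocator {
 public:
  std::string GetUniqueName(const std::string& prefix) {
    assert(!prefix.empty());  // sketch-only precondition; sanitize first
    std::string unique = prefix;
    auto it = counters_.find(prefix);
    if (it != counters_.end()) {
      // Prefix already used: bump its counter until an unused name appears.
      while (true) {
        std::string candidate = prefix + "_" + std::to_string(++it->second);
        if (counters_.count(candidate) == 0) {
          unique = candidate;
          break;
        }
      }
    }
    counters_[unique] = 0;  // register the name that is handed out
    return unique;
  }

 private:
  std::unordered_map<std::string, int> counters_;
};
```

Under these assumptions, two consecutive variables with empty hints would be named `v` and `v_1`; a bare `v0` would only appear if the hint itself were `0`.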







[GitHub] [incubator-tvm] spectrometerHBH commented on a change in pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r421901168



##########
File path: python/tvm/tir/parser.py
##########
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import, redefined-builtin
+"""Namespace for Tensor-level IR"""
+
+from . import _ffi_api
+
+
+def astext(code):

Review comment:
       It is a subclass of Object in tir/stmt.py







[GitHub] [incubator-tvm] junrushao1994 commented on a change in pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r418371862



##########
File path: src/printer/tir_text_printer.cc
##########
@@ -0,0 +1,735 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file printer/tir_text_printer.cc
+ * \brief Printer to print out the IR text format
+ *        that can be parsed by a parser.
+ */
+
+#include <tvm/ir/module.h>
+#include <tvm/ir/type.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/node/serialization.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <string>
+
+#include "doc.h"
+#include "meta_data.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ *  \brief Meta node collector
+ *  If we decide to put some node into meta, then all the sub-nodes inside
+ *  it need to be put in meta as well, since when parsing we need to know
+ *  whether two refs are the same
+ */
+class MetaCollector : public StmtExprVisitor {
+ public:
+  explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
+
+  void Collect(const ObjectRef& n) {
+    // these nodes can be print directly(StringLiteral or use identifier to identify)
+    if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>()
+        || n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
+      return;
+    }
+    if (n->IsInstance<StmtNode>()) {
+      VisitStmt(Downcast<Stmt>(n));
+    } else if (n->IsInstance<PrimExprNode>()) {
+      VisitExpr(Downcast<PrimExpr>(n));
+    }
+  }
+
+  void VisitStmt(const Stmt& n) override {
+    meta_->GetMetaNode(n);
+    StmtVisitor::VisitStmt(n);
+  }
+
+  void VisitExpr(const PrimExpr& n) override {
+    meta_->GetMetaNode(n);
+    ExprVisitor::VisitExpr(n);
+  }
+
+ private:
+  TextMetaDataContext* meta_;
+};
+
+class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
+                       public ExprFunctor<Doc(const PrimExpr&)>,
+                       public TypeFunctor<Doc(const Type&)> {
+ public:
+  explicit TIRTextPrinter(bool show_meta) : show_meta_(show_meta), meta_collector_(&meta_) {}
+
+  /*! \brief Print the node */
+  Doc Print(const ObjectRef& node);
+
+ private:
+  /*! \brief whether show meta data */
+  bool show_meta_;
+  /*! \brief meta data context */
+  TextMetaDataContext meta_;
+  /*! \brief meta collector */
+  MetaCollector meta_collector_;
+  /*! \brief Map from Var to Doc */
+  std::unordered_map<Var, Doc, ObjectHash, ObjectEqual> memo_var_;
+  /*! \brief Map from Buffer to Doc */
+  std::unordered_map<Buffer, Doc, ObjectHash, ObjectEqual> memo_buf_;
+  /*! \brief name allocation map */
+  std::unordered_map<std::string, int> name_alloc_map_;
+
+  Doc VisitExpr_(const IntImmNode* op) override;
+  Doc VisitExpr_(const FloatImmNode* op) override;
+  Doc VisitExpr_(const StringImmNode* op) override;
+  Doc VisitExpr_(const CastNode* op) override;
+  Doc VisitExpr_(const VarNode* op) override;
+  Doc VisitExpr_(const AddNode* op) override;
+  Doc VisitExpr_(const SubNode* op) override;
+  Doc VisitExpr_(const MulNode* op) override;
+  Doc VisitExpr_(const DivNode* op) override;
+  Doc VisitExpr_(const ModNode* op) override;
+  Doc VisitExpr_(const FloorDivNode* op) override;
+  Doc VisitExpr_(const FloorModNode* op) override;
+  Doc VisitExpr_(const MinNode* op) override;
+  Doc VisitExpr_(const MaxNode* op) override;
+  Doc VisitExpr_(const EQNode* op) override;
+  Doc VisitExpr_(const NENode* op) override;
+  Doc VisitExpr_(const LTNode* op) override;
+  Doc VisitExpr_(const LENode* op) override;
+  Doc VisitExpr_(const GTNode* op) override;
+  Doc VisitExpr_(const GENode* op) override;
+  Doc VisitExpr_(const AndNode* op) override;
+  Doc VisitExpr_(const OrNode* op) override;
+  Doc VisitExpr_(const NotNode* op) override;
+  Doc VisitExpr_(const SelectNode* op) override;
+  Doc VisitExpr_(const BufferLoadNode* op) override;
+  Doc VisitExpr_(const LoadNode* op) override;
+  Doc VisitExpr_(const RampNode* op) override;
+  Doc VisitExpr_(const BroadcastNode* op) override;
+  Doc VisitExpr_(const LetNode* op) override;
+  Doc VisitExpr_(const CallNode* op) override;
+  Doc VisitExpr_(const ShuffleNode* op) override;
+  Doc VisitExpr_(const ReduceNode* op) override;
+  Doc VisitExprDefault_(const Object* op) override;
+
+  Doc VisitStmt_(const LetStmtNode* op) override;
+  Doc VisitStmt_(const AttrStmtNode* op) override;
+  Doc VisitStmt_(const AssertStmtNode* op) override;
+  Doc VisitStmt_(const StoreNode* op) override;
+  Doc VisitStmt_(const BufferStoreNode* op) override;
+  Doc VisitStmt_(const BufferRealizeNode* op) override;
+  Doc VisitStmt_(const AllocateNode* op) override;
+  Doc VisitStmt_(const FreeNode* op) override;
+  Doc VisitStmt_(const IfThenElseNode* op) override;
+  Doc VisitStmt_(const SeqStmtNode* op) override;
+  Doc VisitStmt_(const EvaluateNode* op) override;
+  Doc VisitStmt_(const ForNode* op) override;
+  Doc VisitStmt_(const PrefetchNode* op) override;
+  Doc VisitStmtDefault_(const Object* op) override;
+
+  Doc VisitType_(const PrimTypeNode* node) override;
+  Doc VisitType_(const PointerTypeNode* node) override;
+  Doc VisitType_(const TupleTypeNode* node) override;
+
+  Doc PrintIRModule(const IRModule& module);
+  Doc PrintPrimFunc(const PrimFunc& primFunc);
+  Doc PrintArray(const ArrayNode* op);
+  Doc PrintIterVar(const IterVarNode* op);
+  Doc PrintRange(const RangeNode* op);
+  Doc PrintBuffer(const BufferNode* op);
+  Doc PrintString(const StringObj* op) {
+    return Doc::StrLiteral(op->data);
+  }
+
+  /*!
+   * \brief special method to print out data type
+   * \param dtype The data type
+   */
+  static Doc PrintDType(DataType dtype) {
+    return Doc::Text(runtime::DLDataType2String(dtype));
+  }
+
+  /*!
+   * \brief special method to print out const scalar
+   * \param dtype The data type
+   * \param data The pointer to hold the data.
+   */
+  template <typename T>
+  static Doc PrintConstScalar(DataType dtype, const T* data) {
+    Doc doc;
+    std::ostringstream os;
+    os << data[0];
+    if (dtype == DataType::Int(32)) {
+      doc << Doc::Text(os.str());
+    } else {
+      doc << PrintDType(dtype) << "(" << Doc::Text(os.str()) << ")";
+    }
+    return doc;
+  }
+
+  Doc GetUniqueName(std::string prefix) {
+    // std::replace(prefix.begin(), prefix.end(), '.', '_');
+    std::string unique_prefix = prefix;
+    auto it = name_alloc_map_.find(prefix);
+    if (it != name_alloc_map_.end()) {
+      while (true) {
+        std::ostringstream os;
+        os << prefix << "_" << (++it->second);
+        std::string name = os.str();
+        if (name_alloc_map_.count(name) == 0) {
+          unique_prefix = name;
+          break;
+        }
+      }
+    }
+    name_alloc_map_[unique_prefix] = 0;
+    return Doc::Text(unique_prefix);
+  }
+
+  Doc AllocVar(const Var& var) {
+    const auto& it = memo_var_.find(var);
+    if (it != memo_var_.end()) {
+      return it->second;
+    }
+    std::string name = var->name_hint;
+    if (name.length() == 0 || !std::isalpha(name[0])) {
+      name = "v" + name;
+    }
+    Doc val = GetUniqueName(name);
+    memo_var_[var] = val;
+    return val << ": " << Print(GetType(var));
+  }
+
+  Doc AllocBuf(const Buffer& buffer) {
+    const auto& it = memo_buf_.find(buffer);
+    if (it != memo_buf_.end()) {
+      return it->second;
+    }
+    std::string name = buffer->name;
+    if (name.length() == 0 || !std::isalpha(name[0])) {
+      name = "buf_" + name;
+    }
+    Doc val = GetUniqueName(name);
+    memo_buf_[buffer] = val;
+    return val;
+  }
+
+  /*!
+   * \brief special method to render vectors of docs with a separator
+   * \param vec vector of docs
+   * \param sep separator
+   */
+  static Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
+    Doc seq;
+    if (vec.size() != 0) {
+      seq = vec[0];
+      for (size_t i = 1; i < vec.size(); i++) {
+        seq << sep << vec[i];
+      }
+    }
+    return seq;
+  }
+
+  /*!
+   * \brief dump meta info
+   * \return Doc with meta info
+   */
+  Doc DumpMeta() {
+    if (show_meta_) {
+      return Doc::Text("__tvm_meta__ = ")
+          << (meta_.empty() ? Doc::Text("None") : meta_.GetMetaSection());
+    } else {
+      return Doc::Text("");
+    }
+  }
+
+  Doc PrintBody(const Stmt& body, bool indent = true) {
+    Doc doc;
+    if (body->IsInstance<SeqStmtNode>()) return Print(body);
+    doc << " {" << Doc::Indent(2, Doc::NewLine() << Print(body)) << Doc::NewLine() << "}";
+    return doc;
+  }
+};
+
+Doc TIRTextPrinter::Print(const ObjectRef& node) {
+  if (!node.defined()) return Doc::Text("(nullptr)");
+  if (node.as<StmtNode>()) {
+    return VisitStmt(Downcast<Stmt>(node));
+  } else if (node.as<PrimExprNode>()) {
+    return VisitExpr(Downcast<PrimExpr>(node));
+  } else if (node.as<TypeNode>()) {
+    return VisitType(Downcast<Type>(node));
+  } else if (node.as<PrimFuncNode>()) {
+    return PrintPrimFunc(Downcast<PrimFunc>(node));
+  } else if (node.as<IRModuleNode>()) {
+    return PrintIRModule(Downcast<IRModule>(node));
+  } else if (node.as<ArrayNode>()) {
+    return PrintArray(node.as<ArrayNode>());
+  } else if (node.as<IterVarNode>()) {
+    return PrintIterVar(node.as<IterVarNode>());
+  } else if (node.as<RangeNode>()) {
+    return PrintRange(node.as<RangeNode>());
+  } else if (node.as<BufferNode>()) {
+    return PrintBuffer(node.as<BufferNode>());
+  } else if (node.as<StringObj>()) {
+    return PrintString(node.as<StringObj>());
+  } else {
+    return this->meta_.GetMetaNode(node);
+  }
+}
+
+Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
+  const auto* op = primFunc.operator->();
+  const auto& signature = op->func_type_annotation();
+  // collect Meta in DictAttr
+  for (const auto& it : primFunc->attrs->dict) {
+    meta_collector_.Collect(it.second);
+  }
+  // collect buffers in buffer_map
+  memo_var_.clear();
+  memo_buf_.clear();
+  for (const auto& it : op->buffer_map) {
+    memo_buf_[it.second] = AllocBuf(it.second);
+  }
+  // print PrimFunc
+  Doc doc;
+  doc << "primfn" << "(";
+  // print params and its type annotation
+  std::vector<Doc> params;
+  for (const auto& param : op->params) {
+    params.push_back(Print(param));
+  }
+  Doc sep;
+  doc << PrintSep(params, Doc::Indent(9, Doc::Text(", "))) << ")";
+  // print return type
+  doc << " -> " << Print(signature->ret_type);
+  // print attr
+  Doc attr_doc;
+  std::vector<Doc> attr_docs;
+  for (const auto& it : op->attrs->dict) {
+    attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
+  }
+  attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}";
+  doc << Doc::Indent(2, attr_doc);
+  // print all the buffers in the tree
+  Doc buffer_doc;
+  std::vector<Doc> buffer_docs;
+  for (const auto& it : memo_buf_) {
+    const auto& buf = it.first;
+    buffer_docs.push_back(Print(buf)
+                          << Doc::Text(": Buffer(") << Print(buf->data) << ", "
+                          << PrintDType(buf->dtype) << ", " << Print(buf->shape) << ", "
+                          << Print(buf->strides));
+    if (!is_zero(buf->elem_offset)) {
+      buffer_docs.back() << ", elem_offset=" << Print(buf->elem_offset);
+    }
+    if (buf->scope != "global") {
+      buffer_docs.back() << ", scope=" << Doc::StrLiteral(buf->scope);
+    }
+    if (buf->data_alignment != 128) {
+      buffer_docs.back() << ", align=" << buf->data_alignment;
+    }
+    if (buf->offset_factor != 1) {
+      buffer_docs.back() << ", offset_factor=" << buf->offset_factor;
+    }
+    if (buf->buffer_type != 1) {
+      buffer_docs.back() << ", type=" << Doc::StrLiteral("auto");
+    }
+    buffer_docs.back() << ")";
+  }
+  buffer_doc << Doc::NewLine() << "buffers = {";
+  buffer_doc << PrintSep(buffer_docs, Doc::Indent(9, Doc::Text(",") << Doc::NewLine()));
+  doc << Doc::Indent(2, buffer_doc) << "}";
+  // print buffer_map
+  std::vector<Doc> buffer_map_doc;
+  for (const auto& it : op->buffer_map) {
+    buffer_map_doc.push_back(Print(it.first) << ": " << Print(it.second));
+  }
+  doc << Doc::Indent(2, Doc::NewLine()
+      << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
+  doc << PrintBody(op->body);
+  return doc;
+}
+
+Doc TIRTextPrinter::PrintIRModule(const IRModule& module) {
+  const auto* op = module.operator->();
+  Doc doc;
+
+  Doc body;
+  body << Doc::NewLine();
+  std::vector<Doc> functions;
+  for (auto it = op->functions.begin(); it != op->functions.end(); ++it) {
+    if ((*it).second.as<PrimFuncNode>()) {
+      functions.push_back(Print((*it).second));
+    }
+  }
+  body << TIRTextPrinter::PrintSep(functions, Doc::NewLine() << Doc::NewLine());
+  body << Doc::NewLine() << DumpMeta();
+  doc << Doc::Indent(0, body);
+  return doc;
+}
+
+Doc TIRTextPrinter::PrintArray(const ArrayNode* op) {
+  Doc doc;
+  doc << '[';
+  for (size_t i = 0; i < op->data.size(); ++i) {
+    if (i != 0) {
+      doc << ", ";
+    }
+    doc << Print(op->data[i]);
+  }
+  doc << ']';
+  return doc;
+}
+
+Doc TIRTextPrinter::PrintIterVar(const IterVarNode* op) {
+  Doc doc;
+  doc << "IterVar(" << Print(op->var) << ", [" << Print(op->dom) << "], "
+      << Doc::StrLiteral(IterVarType2String(op->iter_type)) << ", "
+      << Doc::StrLiteral(op->thread_tag) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::PrintRange(const RangeNode* op) {
+  return Print(op->min) << ":" << Print(op->min + op->extent);
+}
+
+Doc TIRTextPrinter::PrintBuffer(const BufferNode* op) {
+  const Buffer& buffer = GetRef<Buffer>(op);
+  CHECK_GT(memo_buf_.count(buffer), 0);
+  return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : memo_buf_[buffer];
+}
+
+Doc TIRTextPrinter::VisitExprDefault_(const Object* op) {
+  return this->meta_.GetMetaNode(GetRef<ObjectRef>(op));
+}
+
+Doc TIRTextPrinter::VisitStmtDefault_(const Object* op) {
+  return this->meta_.GetMetaNode(GetRef<ObjectRef>(op));
+}
+
+Doc TIRTextPrinter::VisitExpr_(const IntImmNode* op) {
+  return PrintConstScalar<int64_t>(op->dtype, &(op->value));
+}
+
+Doc TIRTextPrinter::VisitExpr_(const FloatImmNode* op) {
+  return PrintConstScalar<double>(op->dtype, &(op->value));
+}
+
+Doc TIRTextPrinter::VisitExpr_(const StringImmNode* op) { return Doc::StrLiteral(op->value); }
+
+Doc TIRTextPrinter::VisitExpr_(const CastNode* op) {
+  Doc doc;
+  doc << "cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const VarNode* op) {
+  const Var& var = GetRef<Var>(op);
+  return meta_.InMeta(var) ? meta_.GetMetaNode(var) : AllocVar(GetRef<Var>(op));
+}
+
+#define TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OpName, OpString)     \
+  Doc TIRTextPrinter::VisitExpr_(const OpName* op) {               \
+    Doc doc;                                                       \
+    doc << '(' << Print(op->a) << OpString << Print(op->b) << ")"; \
+    return doc;                                                    \
+  }
+
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AddNode, " + ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(SubNode, " - ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(MulNode, "*")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(DivNode, " / ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(ModNode, " % ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(EQNode, " == ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(NENode, " != ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LTNode, " < ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(LENode, " <= ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GTNode, " > ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(GENode, " >= ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(AndNode, " and ")
+TVM_DECLARE_TIR_HYBRID_PRINTER_BINOP(OrNode, " or ")
+
+Doc TIRTextPrinter::VisitExpr_(const FloorDivNode* op) {
+  Doc doc;
+  doc << "floordiv(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const FloorModNode* op) {
+  Doc doc;
+  doc << "floormod(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const MinNode* op) {
+  Doc doc;
+  doc << "min(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const MaxNode* op) {
+  Doc doc;
+  doc << "max(" << Print(op->a) << ", " << Print(op->b) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const NotNode* op) {
+  Doc doc;
+  doc << "!" << Print(op->a);
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const SelectNode* op) {
+  Doc doc;
+  doc << "select(" << Print(op->condition) << ", " << Print(op->true_value) << ", "
+      << Print(op->false_value) << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const BufferLoadNode* op) {
+  Doc doc;
+  doc << Print(op->buffer) << Print(op->indices);
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const LoadNode* op) {
+  Doc doc;
+  doc << "load(" << PrintDType(op->dtype) << ", "
+      << Print(op->buffer_var) << "[" << Print(op->index) << "])";
+  if (!is_one(op->predicate)) {
+    doc << " if " << Print(op->predicate);
+  }
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const RampNode* op) {
+  Doc doc;
+  doc << "ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " << op->lanes << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const BroadcastNode* op) {
+  Doc doc;
+  doc << "broadcast(" << Print(op->value) << ", " << op->lanes << ")";
+  return doc;
+}
+
+Doc TIRTextPrinter::VisitExpr_(const LetNode* op) {
+  Doc doc;
+  doc << "let " << Print(op->var) << " = " << Print(op->value) << " in " << Print(op->body);
+  return doc;
+}
+
+inline const char* CallType2String(CallNode::CallType t) {
+  switch (t) {
+    case CallNode::Extern: return "extern";
+    case CallNode::ExternCPlusPlus: return "extern_cpp";
+    case CallNode::PureExtern: return "pure_extern";
+    case CallNode::Halide: return "halide";
+    case CallNode::Intrinsic: return "intrin";
+    case CallNode::PureIntrinsic: return "pure_intrin";
+  }
+  return "Unknown";

Review comment:
       It might be better to just `LOG(FATAL)` in this case?
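
A minimal sketch of what this suggestion could look like, written outside of TVM so that it stays self-contained. The `CallType` enum is a stand-in for `CallNode::CallType`, and throwing `std::logic_error` is a stand-in for `LOG(FATAL)`; both are assumptions for illustration, not the code that was merged.

```cpp
#include <stdexcept>
#include <string>

// Stand-in for CallNode::CallType (hypothetical; mirrors the enumerators in the diff).
enum class CallType : int {
  Extern, ExternCPlusPlus, PureExtern, Halide, Intrinsic, PureIntrinsic
};

// Map a call type to its printed name. An unrecognised value is treated as a
// programming error instead of silently printing "Unknown".
inline const char* CallType2String(CallType t) {
  switch (t) {
    case CallType::Extern: return "extern";
    case CallType::ExternCPlusPlus: return "extern_cpp";
    case CallType::PureExtern: return "pure_extern";
    case CallType::Halide: return "halide";
    case CallType::Intrinsic: return "intrin";
    case CallType::PureIntrinsic: return "pure_intrin";
  }
  // Inside TVM this is where LOG(FATAL) would go; here we throw instead.
  throw std::logic_error("unknown call type: " + std::to_string(static_cast<int>(t)));
}
```

Failing loudly here means a newly added enumerator that the printer does not handle is noticed immediately, rather than producing text that a future parser cannot read back.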




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] spectrometerHBH commented on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH commented on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-622203710


   > Some comments:
   > 
   > 1. There is more than one space before the left brace in the allocation line
   >    ```
   >    allocate(B.local, float32, [64])  {
   >    ```
   > 2. Can we use the same rule for the allocation stmt as for attr? The allocation stmt currently introduces extra indentation.
   > 3. ```
   >     attr [IterVar(blockIdx.z: int32, [(nullptr)], "ThreadIndex", "blockIdx.z")] "thread_extent" = 196;
   >    ```
   >    
   >    
   >    It is strange to print `nullptr` here, especially in square brackets. Perhaps we can use `IterVar(blockIdx.z: int32, (nullptr), "ThreadIndex", "blockIdx.z")` or `IterVar(blockIdx.z: int32, , "ThreadIndex", "blockIdx.z")`
   > 4. Considering future parsing use, we must print the dtype for every constant. We could use shorthand for common dtypes, e.g. `2f` for float32, `2h` for float16 (half), and plain `2` for int32 (most integer constants in schedules are int32), while still keeping the complete form for every type, e.g. `int8(2)`, `float64(2)` (or maybe `fp64(2)`); `float32(2)` would be legal as well.
   
   1. fixed
   2. fixed. But here we implicitly assume that `Allocate` and `Attr` will have at least one child. Otherwise, for such a scenario
   ```c++
   attr...;
   attr...;
   attr...;
   ```
   we cannot determine whether it is `attr|attr|attr` or `attr(attr)|attr`.
   3. fixed
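
The ambiguity described in point 2 can be sketched with a toy statement tree. Everything below (`Stmt`, `Attr`, `Seq`, `PrintFlat`) is hypothetical and much simpler than TVM's real `Stmt` classes; it only illustrates that, when a body is printed at the same indentation without braces, two different trees can produce identical text.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Stmt;
using StmtPtr = std::shared_ptr<Stmt>;

// Toy statement node (hypothetical): an attr carries text and at most one
// child body; a sequence has empty text and any number of children.
struct Stmt {
  std::string text;
  std::vector<StmtPtr> kids;
};

// Build an attr statement, optionally wrapping a body.
StmtPtr Attr(const std::string& name, StmtPtr body = nullptr) {
  auto s = std::make_shared<Stmt>();
  s->text = "attr " + name + ";";
  if (body) s->kids.push_back(std::move(body));
  return s;
}

// Build a sequence of sibling statements.
StmtPtr Seq(std::vector<StmtPtr> stmts) {
  auto s = std::make_shared<Stmt>();
  s->kids = std::move(stmts);
  return s;
}

// "Flat" rule: a body is printed on the following lines at the same
// indentation, with no braces, so nesting leaves no trace in the output.
std::string PrintFlat(const StmtPtr& s) {
  std::string out;
  if (!s->text.empty()) out += s->text + "\n";
  for (const auto& kid : s->kids) out += PrintFlat(kid);
  return out;
}
```

With these helpers, `Seq({Attr("a"), Attr("b"), Attr("c")})` and `Seq({Attr("a", Attr("b")), Attr("c")})` print the same three lines, which is why the flat rule is only unambiguous when every `Attr` and `Allocate` is known to own the statement that follows it.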





[GitHub] [incubator-tvm] spectrometerHBH edited a comment on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH edited a comment on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-625889950


   I cannot reproduce the CI errors at the moment. The errors reported in CI look strange. I will try to fix them.





[GitHub] [incubator-tvm] tqchen commented on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-621600022


   cc @jroesch @antinucleon @junrushao1994 @Hzfengsy @mbrookhart 





[GitHub] [incubator-tvm] tqchen commented on a change in pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
tqchen commented on a change in pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#discussion_r421561409



##########
File path: python/tvm/tir/parser.py
##########
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import, redefined-builtin
+"""Namespace for Tensor-level IR"""
+
+from . import _ffi_api
+
+
+def astext(code):

Review comment:
        We probably should not put `astext` here, because `irnode.astext` already redirects to `astext` in the backend. One option is to hack relay's `astext` to redirect; alternatively, overload the `astext` member function of the TIR nodes (PrimExpr, PrimFunc).







[GitHub] [incubator-tvm] tqchen commented on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-626193743


   Thanks @spectrometerHBH @xqdan @junrushao1994 @Hzfengsy !





[GitHub] [incubator-tvm] spectrometerHBH edited a comment on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH edited a comment on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-622238350


   To avoid making this page too long, I will edit and update the examples in the top comment for reference whenever I change the format.





[GitHub] [incubator-tvm] spectrometerHBH edited a comment on pull request #5483: [TIR][Printer] text format printer considering future parsing use

Posted by GitBox <gi...@apache.org>.
spectrometerHBH edited a comment on pull request #5483:
URL: https://github.com/apache/incubator-tvm/pull/5483#issuecomment-626096064


   I found the problem. It is due to the different behavior of gcc 5.4 and gcc 7.4 on this line:
   ```c++
   doc << "(" << Print(op->a) << OpString << Print(op->b) << ")";
   ```
   
   gcc 5.4 evaluates `Print(op->b)` first.
   gcc 7.4 evaluates `Print(op->a)` first.
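
Both compilers are within their rights here: before C++17, the order in which the operands of a chained `<<` expression are evaluated is unspecified. That matters whenever printing has side effects, for example when a variable's type is printed only the first time the variable is encountered. The sketch below shows one way to make the order explicit by evaluating each operand into a named local first. `Printer`, `PrintVar` and `PrintAdd` are hypothetical names for illustration, and this thread does not show whether the actual fix in the PR took this form.

```cpp
#include <set>
#include <sstream>
#include <string>

// Toy printer (hypothetical): the first time a variable is printed it is
// annotated with its type; later uses print only the name.
class Printer {
 public:
  std::string PrintVar(const std::string& name) {
    bool first = seen_.insert(name).second;
    return first ? name + ": int32" : name;
  }

  // Writing
  //   os << "(" << PrintVar(a) << " + " << PrintVar(b) << ")";
  // leaves the order of the two PrintVar calls unspecified before C++17.
  // Evaluating into named locals fixes the order on every compiler.
  std::string PrintAdd(const std::string& a, const std::string& b) {
    std::string lhs = PrintVar(a);  // always evaluated first
    std::string rhs = PrintVar(b);  // always evaluated second
    std::ostringstream os;
    os << "(" << lhs << " + " << rhs << ")";
    return os.str();
  }

 private:
  std::set<std::string> seen_;  // variables that have already been printed
};
```

With this version the type annotation always lands on the left-most first use of a variable, regardless of compiler version or language standard.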

