You are viewing a plain text version of this content. The canonical link for it is here.
Posted to discuss-archive@tvm.apache.org by JiaNan Wang via Apache TVM Discuss <no...@discuss.tvm.ai> on 2022/03/22 02:12:02 UTC

[Apache TVM Discuss] [Questions] Can One reduce stage fuse into another reduce stage?


Hello, I am trying to fuse one convolution layer's computation, together with its ReLU result, into the next layer's convolution computation. I tried two methods: one is to use a te.sum expression as a parameter of another te.sum, and the other is to use s.compute_inline(), but both fail. I would like to know whether it is possible to combine two reduce stages (te.sum) into one reduce stage in TE; if not, can Relay or TIR express this computation? Here is the current TIR without fusion:

        primfn(args: handle, arg_type_ids: handle, num_args: int32, out_ret_value: handle, out_ret_tcode: handle, resource_handle: handle) -> int32
      attr = {"target": meta[Target][0], "tir.noalias": True, "global_symbol": "myfunc_fusion", "from_legacy_te_schedule": True, "tir.is_entry_func": True, "calling_conv": 1} {
      assert((num_args == 4), "myfunc_fusion: num_args should be 4")
      let arg0: handle = @tir.tvm_struct_get(args, 0, 12, dtype=handle)
      let arg0.code: int32 = (int32*)arg_type_ids[0]
      let arg1: handle = @tir.tvm_struct_get(args, 1, 12, dtype=handle)
      let arg1.code: int32 = (int32*)arg_type_ids[1]
      let arg2: handle = @tir.tvm_struct_get(args, 2, 12, dtype=handle)
      let arg2.code: int32 = (int32*)arg_type_ids[2]
      let arg3: handle = @tir.tvm_struct_get(args, 3, 12, dtype=handle)
      let arg3.code: int32 = (int32*)arg_type_ids[3]
      let A: Pointer(float32) = @tir.tvm_struct_get(arg0, 0, 1, dtype=handle)
      attr [A] "storage_alignment" = 128;
      let arg0.shape: handle = @tir.tvm_struct_get(arg0, 0, 2, dtype=handle)
      let arg0.strides: handle = @tir.tvm_struct_get(arg0, 0, 3, dtype=handle)
      let dev_id: int32 = @tir.tvm_struct_get(arg0, 0, 9, dtype=int32)
      let W: Pointer(float32) = @tir.tvm_struct_get(arg1, 0, 1, dtype=handle)
      attr [W] "storage_alignment" = 128;
      let arg1.shape: handle = @tir.tvm_struct_get(arg1, 0, 2, dtype=handle)
      let arg1.strides: handle = @tir.tvm_struct_get(arg1, 0, 3, dtype=handle)
      let W_2: Pointer(float32) = @tir.tvm_struct_get(arg2, 0, 1, dtype=handle)
      attr [W_2] "storage_alignment" = 128;
      let arg2.shape: handle = @tir.tvm_struct_get(arg2, 0, 2, dtype=handle)
      let arg2.strides: handle = @tir.tvm_struct_get(arg2, 0, 3, dtype=handle)
      let C: Pointer(float32) = @tir.tvm_struct_get(arg3, 0, 1, dtype=handle)
      attr [C] "storage_alignment" = 128;
      let arg3.shape: handle = @tir.tvm_struct_get(arg3, 0, 2, dtype=handle)
      let arg3.strides: handle = @tir.tvm_struct_get(arg3, 0, 3, dtype=handle)
      assert(((((arg0.code == 3) || (arg0.code == 13)) || (arg0.code == 7)) || (arg0.code == 4)), "myfunc_fusion: Expect arg[0] to be pointer")
      assert(((((arg1.code == 3) || (arg1.code == 13)) || (arg1.code == 7)) || (arg1.code == 4)), "myfunc_fusion: Expect arg[1] to be pointer")
      assert(((((arg2.code == 3) || (arg2.code == 13)) || (arg2.code == 7)) || (arg2.code == 4)), "myfunc_fusion: Expect arg[2] to be pointer")
      assert(((((arg3.code == 3) || (arg3.code == 13)) || (arg3.code == 7)) || (arg3.code == 4)), "myfunc_fusion: Expect arg[3] to be pointer")
      assert((4 == @tir.tvm_struct_get(arg0, 0, 4, dtype=int32)), "arg0.ndim is expected to equal 4")
      assert((4 == @tir.tvm_struct_get(arg0, 0, 4, dtype=int32)), "arg0.ndim is expected to equal 4")
      assert((((@tir.tvm_struct_get(arg0, 0, 5, dtype=uint8) == 2u8) && (@tir.tvm_struct_get(arg0, 0, 6, dtype=uint8) == 32u8)) && (@tir.tvm_struct_get(arg0, 0, 7, dtype=uint16) == 1u16)), "arg0.dtype is expected to be float32")
      assert((56 == cast(int32, (int64*)arg0.shape[0])), "Argument arg0.shape[0] has an unsatisfied constraint: (56 == int32(arg0.shape[0]))")
      assert((56 == cast(int32, (int64*)arg0.shape[1])), "Argument arg0.shape[1] has an unsatisfied constraint: (56 == int32(arg0.shape[1]))")
      assert((64 == cast(int32, (int64*)arg0.shape[2])), "Argument arg0.shape[2] has an unsatisfied constraint: (64 == int32(arg0.shape[2]))")
      assert((3 == cast(int32, (int64*)arg0.shape[3])), "Argument arg0.shape[3] has an unsatisfied constraint: (3 == int32(arg0.shape[3]))")
       {
        if !@tir.isnullptr(arg0.strides, dtype=bool) {
          assert(((((1 == cast(int32, (int64*)arg0.strides[3])) && (3 == cast(int32, (int64*)arg0.strides[2]))) && (192 == cast(int32, (int64*)arg0.strides[1]))) && (10752 == cast(int32, (int64*)arg0.strides[0]))), "arg0.strides: expected to be compact array")
          0
        }
        assert((0u64 == @tir.tvm_struct_get(arg0, 0, 8, dtype=uint64)), "Argument arg0.byte_offset has an unsatisfied constraint: ((uint64)0 == tir.tvm_struct_get(arg0, 0, 8))")
        assert((1 == @tir.tvm_struct_get(arg0, 0, 10, dtype=int32)), "Argument arg0.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg0, 0, 10))")
        assert((4 == @tir.tvm_struct_get(arg1, 0, 4, dtype=int32)), "arg1.ndim is expected to equal 4")
        assert((4 == @tir.tvm_struct_get(arg1, 0, 4, dtype=int32)), "arg1.ndim is expected to equal 4")
        assert((((@tir.tvm_struct_get(arg1, 0, 5, dtype=uint8) == 2u8) && (@tir.tvm_struct_get(arg1, 0, 6, dtype=uint8) == 32u8)) && (@tir.tvm_struct_get(arg1, 0, 7, dtype=uint16) == 1u16)), "arg1.dtype is expected to be float32")
        assert((3 == cast(int32, (int64*)arg1.shape[0])), "Argument arg1.shape[0] has an unsatisfied constraint: (3 == int32(arg1.shape[0]))")
        assert((3 == cast(int32, (int64*)arg1.shape[1])), "Argument arg1.shape[1] has an unsatisfied constraint: (3 == int32(arg1.shape[1]))")
        assert((64 == cast(int32, (int64*)arg1.shape[2])), "Argument arg1.shape[2] has an unsatisfied constraint: (64 == int32(arg1.shape[2]))")
        assert((64 == cast(int32, (int64*)arg1.shape[3])), "Argument arg1.shape[3] has an unsatisfied constraint: (64 == int32(arg1.shape[3]))")
         {
          if !@tir.isnullptr(arg1.strides, dtype=bool) {
            assert(((((1 == cast(int32, (int64*)arg1.strides[3])) && (64 == cast(int32, (int64*)arg1.strides[2]))) && (4096 == cast(int32, (int64*)arg1.strides[1]))) && (12288 == cast(int32, (int64*)arg1.strides[0]))), "arg1.strides: expected to be compact array")
            0
          }
          assert((0u64 == @tir.tvm_struct_get(arg1, 0, 8, dtype=uint64)), "Argument arg1.byte_offset has an unsatisfied constraint: ((uint64)0 == tir.tvm_struct_get(arg1, 0, 8))")
          assert((1 == @tir.tvm_struct_get(arg1, 0, 10, dtype=int32)), "Argument arg1.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg1, 0, 10))")
          assert((dev_id == @tir.tvm_struct_get(arg1, 0, 9, dtype=int32)), "Argument arg1.device_id has an unsatisfied constraint: (dev_id == tir.tvm_struct_get(arg1, 0, 9))")
          assert((4 == @tir.tvm_struct_get(arg2, 0, 4, dtype=int32)), "arg2.ndim is expected to equal 4")
          assert((4 == @tir.tvm_struct_get(arg2, 0, 4, dtype=int32)), "arg2.ndim is expected to equal 4")
          assert((((@tir.tvm_struct_get(arg2, 0, 5, dtype=uint8) == 2u8) && (@tir.tvm_struct_get(arg2, 0, 6, dtype=uint8) == 32u8)) && (@tir.tvm_struct_get(arg2, 0, 7, dtype=uint16) == 1u16)), "arg2.dtype is expected to be float32")
          assert((3 == cast(int32, (int64*)arg2.shape[0])), "Argument arg2.shape[0] has an unsatisfied constraint: (3 == int32(arg2.shape[0]))")
          assert((3 == cast(int32, (int64*)arg2.shape[1])), "Argument arg2.shape[1] has an unsatisfied constraint: (3 == int32(arg2.shape[1]))")
          assert((64 == cast(int32, (int64*)arg2.shape[2])), "Argument arg2.shape[2] has an unsatisfied constraint: (64 == int32(arg2.shape[2]))")
          assert((64 == cast(int32, (int64*)arg2.shape[3])), "Argument arg2.shape[3] has an unsatisfied constraint: (64 == int32(arg2.shape[3]))")
           {
            if !@tir.isnullptr(arg2.strides, dtype=bool) {
              assert(((((1 == cast(int32, (int64*)arg2.strides[3])) && (64 == cast(int32, (int64*)arg2.strides[2]))) && (4096 == cast(int32, (int64*)arg2.strides[1]))) && (12288 == cast(int32, (int64*)arg2.strides[0]))), "arg2.strides: expected to be compact array")
              0
            }
            assert((0u64 == @tir.tvm_struct_get(arg2, 0, 8, dtype=uint64)), "Argument arg2.byte_offset has an unsatisfied constraint: ((uint64)0 == tir.tvm_struct_get(arg2, 0, 8))")
            assert((1 == @tir.tvm_struct_get(arg2, 0, 10, dtype=int32)), "Argument arg2.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg2, 0, 10))")
            assert((dev_id == @tir.tvm_struct_get(arg2, 0, 9, dtype=int32)), "Argument arg2.device_id has an unsatisfied constraint: (dev_id == tir.tvm_struct_get(arg2, 0, 9))")
            assert((4 == @tir.tvm_struct_get(arg3, 0, 4, dtype=int32)), "arg3.ndim is expected to equal 4")
            assert((4 == @tir.tvm_struct_get(arg3, 0, 4, dtype=int32)), "arg3.ndim is expected to equal 4")
            assert((((@tir.tvm_struct_get(arg3, 0, 5, dtype=uint8) == 2u8) && (@tir.tvm_struct_get(arg3, 0, 6, dtype=uint8) == 32u8)) && (@tir.tvm_struct_get(arg3, 0, 7, dtype=uint16) == 1u16)), "arg3.dtype is expected to be float32")
            assert((54 == cast(int32, (int64*)arg3.shape[0])), "Argument arg3.shape[0] has an unsatisfied constraint: (54 == int32(arg3.shape[0]))")
            assert((54 == cast(int32, (int64*)arg3.shape[1])), "Argument arg3.shape[1] has an unsatisfied constraint: (54 == int32(arg3.shape[1]))")
            assert((64 == cast(int32, (int64*)arg3.shape[2])), "Argument arg3.shape[2] has an unsatisfied constraint: (64 == int32(arg3.shape[2]))")
            assert((3 == cast(int32, (int64*)arg3.shape[3])), "Argument arg3.shape[3] has an unsatisfied constraint: (3 == int32(arg3.shape[3]))")
             {
              if !@tir.isnullptr(arg3.strides, dtype=bool) {
                assert(((((1 == cast(int32, (int64*)arg3.strides[3])) && (3 == cast(int32, (int64*)arg3.strides[2]))) && (192 == cast(int32, (int64*)arg3.strides[1]))) && (10368 == cast(int32, (int64*)arg3.strides[0]))), "arg3.strides: expected to be compact array")
                0
              }
              assert((0u64 == @tir.tvm_struct_get(arg3, 0, 8, dtype=uint64)), "Argument arg3.byte_offset has an unsatisfied constraint: ((uint64)0 == tir.tvm_struct_get(arg3, 0, 8))")
              assert((1 == @tir.tvm_struct_get(arg3, 0, 10, dtype=int32)), "Argument arg3.device_type has an unsatisfied constraint: (1 == tir.tvm_struct_get(arg3, 0, 10))")
              assert((dev_id == @tir.tvm_struct_get(arg3, 0, 9, dtype=int32)), "Argument arg3.device_id has an unsatisfied constraint: (dev_id == tir.tvm_struct_get(arg3, 0, 9))")
              attr [0] "compute_scope" = "myfunc_fusion_compute_";
              attr [R: Pointer(global float32)] "storage_alignment" = 128 {
                let R = @tir.TVMBackendAllocWorkspace(1, dev_id, 2239488u64, 2, 32, dtype=handle)
                 {
                  if @tir.isnullptr(R, dtype=bool) {
                    @tir.tvm_throw_last_error(, dtype=int32)
                  }
                  allocate(B: Pointer(global float32), float32, [1]), storage_scope = global {
                    for (yy: int32, 0, 54) {
                      for (xx: int32, 0, 54) {
                        for (cc: int32, 0, 64) {
                          for (batch: int32, 0, 3) {
                            B[0] = 0f32
                            for (ry: int32, 0, 3) {
                              for (rx: int32, 0, 3) {
                                for (rc: int32, 0, 64) {
                                  B[0] = @tir.call_llvm_pure_intrin(134u32, 3u32, (float32*)A[((((((yy*10752) + (ry*10752)) + (xx*192)) + (rx*192)) + (rc*3)) + batch)], (float32*)W[((((ry*12288) + (rx*4096)) + (rc*64)) + cc)], (float32*)B[0], dtype=float32)
                                }
                              }
                            }
                            R[((((yy*10368) + (xx*192)) + (cc*3)) + batch)] = max(0f32, (float32*)B[0])
                          }
                        }
                      }
                    }
                    for (yy_1: int32, 0, 54) {
                      for (xx_1: int32, 0, 54) {
                        for (ff: int32, 0, 64) {
                          for (nn: int32, 0, 3) {
                            C[((((yy_1*10368) + (xx_1*192)) + (ff*3)) + nn)] = 0f32
                            for (ry_2: int32, 0, 3) {
                              for (rx_2: int32, 0, 3) {
                                for (rc_2: int32, 0, 64) {
                                  C[((((yy_1*10368) + (xx_1*192)) + (ff*3)) + nn)] = @tir.call_llvm_pure_intrin(134u32, 3u32, (float32*)R[((((((yy_1*10368) + (ry_2*10368)) + (xx_1*192)) + (rx_2*192)) + (rc_2*3)) + nn)], (float32*)W_2[((((ry_2*12288) + (rx_2*4096)) + (rc_2*64)) + ff)], (float32*)C[((((yy_1*10368) + (xx_1*192)) + (ff*3)) + nn)], dtype=float32)
                                }
                              }
                            }
                          }
                        }
                      }
                    }
                  }
                }
                if (@tir.TVMBackendFreeWorkspace(1, dev_id, R, dtype=int32) != 0) {
                  @tir.tvm_throw_last_error(, dtype=int32)
                }
              }
            }
          }
        }
      }
    }





---
[Visit Topic](https://discuss.tvm.apache.org/t/can-one-reduce-stage-fuse-into-another-reduce-stage/12367/1) to respond.

You are receiving this because you enabled mailing list mode.

To unsubscribe from these emails, [click here](https://discuss.tvm.apache.org/email/unsubscribe/75e39fd21722d84c7a39212d0f263c73b6db84798a8e4a1929f107fc6ce53cb1).

[Apache TVM Discuss] [Questions] Can One reduce stage fuse into another reduce stage?

Posted by masahi via Apache TVM Discuss <no...@discuss.tvm.ai>.

[quote="jnwang, post:1, topic:12367"]
I would like to know if it is possible to combine two reduce stages (te.sum) into one reduce stage in te
[/quote]

I'm not sure what you mean here, but if I take it literally, that wouldn't be feasible. 

But there is an example of scheduling fused conv2d -> conv2d: https://github.com/apache/tvm/blob/6aa5ba281ba669d01038ca67b2f6d55ba2299249/tests/python/contrib/test_hexagon/test_conv2d_conv2d.py





---
[Visit Topic](https://discuss.tvm.apache.org/t/can-one-reduce-stage-fuse-into-another-reduce-stage/12367/2) to respond.

You are receiving this because you enabled mailing list mode.

To unsubscribe from these emails, [click here](https://discuss.tvm.apache.org/email/unsubscribe/c7d0773fc14f01c4028899b5c7d51f1e1290dbf38b7a831919e16004d2cfe359).