Posted to commits@tvm.apache.org by tq...@apache.org on 2020/07/14 15:59:19 UTC

[incubator-tvm-site] branch master updated: Blog on PyTorch and TVM interop (#13)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm-site.git


The following commit(s) were added to refs/heads/master by this push:
     new d42ee2f  Blog on PyTorch and TVM interop (#13)
d42ee2f is described below

commit d42ee2f53f0027cf3b4626a4dcd963ee2928aa73
Author: Thomas Viehmann <tv...@beamnet.de>
AuthorDate: Tue Jul 14 17:59:09 2020 +0200

    Blog on PyTorch and TVM interop (#13)
---
 _posts/2020-07-14-bert-pytorch-tvm.md             |  540 +++
 images/bert-pytorch/bert-tvm_49_0.svg             |  691 ++++
 images/bert-pytorch/bert-tvm_54_0.svg             |  691 ++++
 images/bert-pytorch/bert-tvm_65_2.svg             |  667 ++++
 images/bert-pytorch/bert-tvm_68_0.svg             |  667 ++++
 images/bert-pytorch/bert-tvm_70_0.svg             |  667 ++++
 images/bert-pytorch/bert-tvm_72_0.svg             |  559 +++
 images/bert-pytorch/bert-tvm_74_0.svg             |  547 +++
 images/bert-pytorch/bert_layer.svg                |  234 ++
 images/bert-pytorch/bert_model.svg                |  325 ++
 images/bert-pytorch/pytorch-tvm-training_20_0.svg | 1237 +++++++
 images/bert-pytorch/pytorch-tvm-training_25_0.svg | 1537 ++++++++
 images/bert-pytorch/pytorch-tvm-training_31_0.svg | 4015 +++++++++++++++++++++
 images/bert-pytorch/pytorch-tvm-training_34_0.svg | 4015 +++++++++++++++++++++
 images/bert-pytorch/pytorch-tvm-training_40_0.svg | 1651 +++++++++
 15 files changed, 18043 insertions(+)

diff --git a/_posts/2020-07-14-bert-pytorch-tvm.md b/_posts/2020-07-14-bert-pytorch-tvm.md
new file mode 100644
index 0000000..7cc5ecb
--- /dev/null
+++ b/_posts/2020-07-14-bert-pytorch-tvm.md
@@ -0,0 +1,540 @@
+---
+layout: post
+title: "Bridging PyTorch and TVM"
+author: "Thomas Viehmann, MathInf GmbH"
+date: 2020-07-14
+---
+{% include JB/setup %}
+
+(A more code-heavy variant is crossposted on the more PyTorch-affine [Lernapparat](https://lernapparat.de/transformers-pytorch-tvm/),
+ the Jupyter Notebook to follow along is on [GitHub](https://github.com/t-vi/pytorch-tvmisc/tree/master/transformers-pytorch-tvm/).)
+
+Some of the most intriguing applications of Artificial Intelligence have been in Natural Language Processing.
+Models like BERT or GPT-2 and their variants can seemingly grasp enough of a text to continue it in a way that needs a second look to recognize as gibberish.
+
+These models belong to a class of neural network architectures called *Transformers*. One of the favourite libraries
+implementing them is the [HuggingFace transformers library](https://github.com/huggingface/transformers/).
+
+But, in contrast to convolutional models or LSTMs where we have heavily optimized implementations, this is not as much the case for transformers.
+So here we explore how TVM can fill the gap. We will do so in two steps:
+
+- First we look at BERT inference and tuning that on TVM.
+- Secondly, we start some more fundamental exploration of how one could use TVM for training in PyTorch.
+  Given the experimental nature, we focus on feasibility more than on the performance in this part.
+
+# Optimizing BERT Inference with TVM
+
+So how do we get BERT from the transformers library to TVM?
+
+Helpfully, transformers supports tracing their model with the PyTorch JIT. We use their [tutorial on it](https://huggingface.co/transformers/torchscript.html),
+specifically the part until we have a traced model.
+
+The PyTorch traced model takes around 0.65-0.7 seconds for 100 runs on my AMD Radeon VII with the example inputs, which means 6.5-7ms per run.
+Let us see if we can use TVM to get faster. Converting our model to TVM is a breeze:
+
+
+```python
+shape_list = [(i.debugName().split('.')[0], i.type().sizes()) for i in  list(traced_model.graph.inputs())[1:]]
+
+mod_bert, params_bert = tvm.relay.frontend.pytorch.from_pytorch(traced_model,
+                        shape_list, default_dtype="float32")
+```
+
+There will be a few warnings about not finding dtype information, but it goes well!
+We can now build and run it. Building follows the standard TVM recipe. We also convert the PyTorch (cpu) tensors to TVM arrays.
+
+
+```python
+target = 'rocm -model=gfx906'  # use what matches your GPU
+
+target_host = 'llvm'
+ctx = tvm.context(target)
+
+tt_a = tvm.nd.array(tokens_tensor.numpy(), ctx)
+st_a = tvm.nd.array(segments_tensors.numpy(), ctx)
+```
+
+
+```python
+tvm.relay.backend.compile_engine.get().clear() # just to be sure, see https://github.com/apache/incubator-tvm/pull/5724
+
+with tvm.transform.PassContext(opt_level=3):
+        graph, lib, params = tvm.relay.build(mod_bert,
+                                     target=target,
+                                     target_host=target_host,
+                                     params=params_bert)
+module = tvm.contrib.graph_runtime.create(graph, lib, ctx)
+```
+
+This will warn us a few times:
+```
+    WARNING:autotvm:Cannot find config for ... batch_matmul.cuda .... A fallback configuration is used, which may bring great performance regression.
+```
+
+Uh oh, _may bring great performance regression_. Let us see.
+
+But first we run the model and see if the outputs match:
+
+
+```python
+    (8.583069e-06, 8.493662e-07)
+```
+
+Looks good. Remember that we're computing in float32, so $10^{-6}$ish is a good result.
+
+After building our model and setting the parameters, we time our model like this:
+
+```python
+def x():
+    for i in range(100):
+        module.run()
+    ctx.sync()
+x()
+%timeit x()
+```
+
+Ouch, it takes 6.65s per 100 runs, or 67ms per run of the model. That's slow indeed. But the warning said that this was because it could not find (tuned) configurations. Let us then tune the tasks.
+
+Tuning does take half a day or so (I'm basically following the TVM tuning tutorial for ResNet tuning with autotvm.)
+
+After this, we can again build the model, this time with the new configuration. This time we should see no comments about missing configurations.
+Now it's in the region of 6.5-7ms per run, similar to PyTorch. This is what we get from this very elementary optimization of our operators. We can push it a little further, though.
+
+To see how, let us dive deep into BERT modeling and TVM.
+
+If you don't want to get the full details, do skip the next section and scroll down to _Results_. I should add that I would hope that this tuning part of the tutorial will obsolete itself in the sense that in the near future, you will get much better speed right out of the box or at least after some initial tuning. So if you don't see a speedup between here and _Results_, that's because I did my homework in submitting patches.
+
+## The BERT model
+
+Let us take a closer look at what's going on in BERT.
+
+Like many deep learning models, BERT comes with a bit of prologue (vocabulary embeddings) and epilogue (pooling), and the bulk is organized into similar-looking blocks. Here we have 12 `BertLayer` modules.
+The `attention_mask` is just to prevent BERT from looking at the answer when dealing with the question.
+
+![Bert Model](/images/bert-pytorch/bert_model.svg)
+
+So let us zoom in and look at a BertLayer in detail, since that ultimately is what we need to make fast.
+As we see in the net diagram, the main part of the `BertLayer` module is a submodule `BertSelfAttention`.
+
+![BertLayer](/images/bert-pytorch/bert_layer.svg)
+
+Now the `BertSelfAttention` captures the famed self-attention mechanism that is the hallmark of transformer models. (I cannot recommend Sasha Rush's [Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html) enough as a detailed walkthrough.)
+
+## Putting the BertLayer under the Microscope
+
+If we want to go into details, we want to run a BertLayer individually.
+We grab the inputs of a BertLayer (see the Notebook for how) and convert a single `BertLayer` to TVM as we did for the entire model.
+
+To look at the TVM module, we define a little visualization helper (loosely based on TVM [PR#4370](https://github.com/apache/incubator-tvm/pull/4370)).
+
+
+```python
+import graphviz
+import numpy  # used by is_small_const below
+import tvm
+import tvm.relay
+def visualize(expr, collapse_small=True, node_attr_dict = {}):
+    def collect_ops(node):
+        ops = set()
+        def visitor(e):
+            if isinstance(e, tvm.ir.Op):
+                ops.add(e.name)
+        tvm.relay.analysis.post_order_visit(node, visitor)
+        return ops
+
+    # node_dict maps a Relay node to an index (node ID)
+    def _traverse_expr(node, node_dict):
+        if node in node_dict:
+            return
+        node_dict[node] = len(node_dict)
+
+    node_dict = {}
+    tvm.relay.analysis.post_order_visit(expr, lambda x: _traverse_expr(x, node_dict))
+
+    relayviz_nodes = []
+
+    dot = graphviz.Digraph(format='svg', )
+    dot.attr('node', shape = 'box')
+
+    def to_str(node):
+        if isinstance(node, tvm.relay.Constant):
+            # slice off the 'Constant(' prefix; str.lstrip would strip a character set, not a prefix
+            return repr(node)[len('Constant('):-1]
+        else:
+            raise NotImplementedError("to_str:" + repr(node))
+
+    def is_small_const(c):
+        if not (collapse_small and isinstance(c, tvm.relay.Constant)):
+            return False
+        if isinstance(c.data, tvm.runtime.ndarray.NDArray):
+            return numpy.prod(c.data.shape) < 10
+        return True
+            
+    # Sort by node ID
+    for node, node_id in sorted(node_dict.items(), key=lambda x: x[1]):
+        if isinstance(node, tvm.relay.Function):
+            dot.node(str(node_id), 'Function', **node_attr_dict.get(node, {}))
+            dot.edge(str(node_dict[node.body]), str(node_id))
+        elif isinstance(node, tvm.relay.Var):
+            if node.type_annotation is not None:
+                if hasattr(node.type_annotation, 'shape'):
+                    shape = tuple([int(x) for x in node.type_annotation.shape])
+                    dtype = node.type_annotation.dtype
+                    typstr = 'Tensor[{}, {}]'.format(shape, dtype)
+                else:
+                    typstr = str(node.type_annotation)
+            else:
+                typstr = '?'
+            d = dict(shape = 'ellipse')
+            d.update(node_attr_dict.get(node, {}))
+            dot.node(str(node_id),
+                     '{}: {}'.format(
+                         node.name_hint, typstr
+                     ), **d)
+        elif isinstance(node, tvm.relay.Tuple):
+            dot.node(str(node_id), 'Tuple[...]', **node_attr_dict.get(node, {}))
+            for field in node.fields:
+                dot.edge(str(node_dict[field]), str(node_id))
+        elif isinstance(node, tvm.relay.Constant):
+            
+            if not is_small_const(node): # small consts are shown in ops
+                dot.node(str(node_id), 'Constant({}, {})'.format(node.data.shape, node.data.dtype),
+                        **node_attr_dict.get(node, {}))
+        elif isinstance(node, tvm.relay.Call):
+            args_with_edge = []
+            arg_str_list = []
+            for arg in node.args:
+                if is_small_const(arg):
+                    arg_str_list.append(to_str(arg))
+                else:
+                    arg_str_list.append('·')
+                    args_with_edge.append(arg)
+            arg_str = ', '.join(arg_str_list)
+            if isinstance(node.op, tvm.ir.Op):
+                name = node.op.name
+                attrs = {k:getattr(node.attrs, k) for k in node.attrs.keys()} if hasattr(node.attrs, 'keys') else {}
+                #attrs = inspect.getmembers(node.attrs)
+                attr_str_list = [k+'='+(str(v) if len(str(v))<20 else "...") for k, v in attrs.items()]
+                if attr_str_list:
+                    attr_str = '| '+ ', '.join(attr_str_list)
+                else:
+                    attr_str = ''
+            else:
+                ops = collect_ops(node)
+                if ops:
+                    name = '_'.join(ops)
+                else:
+                    name = '...'
+                attr_str = ''
+            s = f'{name}({arg_str}{attr_str})'
+            dot.node(str(node_id), s, **node_attr_dict.get(node, {}))
+            for arg in args_with_edge:
+                dot.edge(str(node_dict[arg]), str(node_id))
+        elif isinstance(node, tvm.ir.Op):
+            # dot.node(str(node_id), 'Op {}'.format(node.name))
+            pass # covered in call
+        elif isinstance(node, tvm.relay.TupleGetItem):
+            dot.node(str(node_id), 'TupleGetItem(idx={})'.format(node.index), **node_attr_dict.get(node, {}))
+            dot.edge(str(node_dict[node.tuple_value]), str(node_id))
+        elif isinstance(node, tvm.relay.Let):
+            dot.node(str(node_id), 'Let(XX)', **node_attr_dict.get(node, {}))
+            dot.edge(str(node_dict[node.value]), str(node_id))
+            dot.edge(str(node_id), str(node_dict[node.var]))
+        else:
+            raise RuntimeError(
+                'Unknown node type. node_id: {}, node: {}'.format(node_id, type(node)))
+
+    return dot
+
+```
+
+Let's run that on our main function. For some reason (well, to be fully general, probably) the PyTorch converter will convert `Linear` layers to `batch_matmul` rather than just `dense`. We'll get back to this in a bit. As TVM's `batch_matmul` has the contraction axis last on both operands (unlike PyTorch), there are quite a few transpose operations, too.
+
+
+```python
+visualize(mod['main'])
+```
+
+![svg](/images/bert-pytorch/bert-tvm_49_0.svg)
+
+
+In addition to our named inputs, we see a number of unnamed (numbered) variables. These are the neural network parameters.
+
+Let us compile our model.
+
+Just like the full model, we can run and time our submodule after checking that it computes the same quantities. 
+
+100 runs take 20.2ms. The back-of-the-envelope calculation here is that with `BertLayer` in PyTorch we are spending about 0.2ms in this layer, so about 2.4ms on 12 layers: not the majority, but a sizeable part of the 6-7ms overall runtime. Let's compare to TVM. (A good rule is to never optimize without measuring.)
+
+Similarly, TVM clocks in at 18.2ms for 100 runs. So here we are again roughly on par with PyTorch.
+
+One thing we see from the picture is that the input is reshaped three times. There is a TVM optimization pass called Common Subexpression Elimination (CSE) that combines the three reshapes.
+(A while ago, this did not succeed because it had distinct shape arguments, but this was since solved by the TVM developers in the dynamic to static conversion pass.)
+Also, the model parameters are reshaped and transposed. Can we get rid of that, too?
+Yes. And for that we would first _bind_ the parameters, i.e. put them into the model. Then the parameters have become constants instead of input nodes.
+With the `FoldConstant` pass, we can propagate the constants through the `transpose`s and `reshape`s to move them closer to the matmuls.
+
+After these three (which TVM will do when we compile a relay model), our model looks like this:
+
+![svg](/images/bert-pytorch/bert-tvm_72_0.svg)
+
+And now comes an interesting trick. It is more efficient to merge the three batch matmuls with the same input into a single `batch_matmul`. We implemented a pass doing this in [TVM PR 5791](https://github.com/apache/incubator-tvm/pull/5791). So let's call it and also have another constant-folding pass.
+
+
+```python
+new_mod = tvm.relay.transform.CombineParallelBatchMatmul()(new_mod)
+new_mod = tvm.relay.transform.FoldConstant()(new_mod)
+visualize(new_mod["main"])
+```
+
+![svg](/images/bert-pytorch/bert-tvm_74_0.svg)
+
+Awesome. After checking that we still get the same result,
+we can time again: 25.2 ms for 100 runs. It's a bit slow again because we need to tune for the new shapes.
+After tuning, we are at 12.6ms for 100 runs, so we went from about 0.2ms to about 0.13-0.15ms, a nice speedup.
+By our handwavy calculation, this should cut 0.6-0.8ms from the total runtime, or somewhere between 5%-10%. Let's check.
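+
+As an aside, the reason the merge is valid can be checked in plain NumPy, independently of TVM. The following is a minimal sketch, not the TVM pass itself: the weight names `w_q`, `w_k`, `w_v` and their shapes are assumptions for illustration, and `batch_matmul` here only mimics TVM's convention of having the contraction axis last on both operands.
+
+```python
+import numpy as np
+
+def batch_matmul(x, y):
+    # TVM-style convention: x is (b, m, k), y is (b, n, k), result is (b, m, n)
+    return np.einsum("bmk,bnk->bmn", x, y)
+
+rng = np.random.default_rng(0)
+x = rng.standard_normal((1, 14, 768))      # shared input
+w_q = rng.standard_normal((1, 768, 768))   # hypothetical query weights
+w_k = rng.standard_normal((1, 768, 768))   # hypothetical key weights
+w_v = rng.standard_normal((1, 768, 768))   # hypothetical value weights
+
+# three separate batch matmuls on the same input
+separate = [batch_matmul(x, w) for w in (w_q, w_k, w_v)]
+
+# one merged batch matmul: stack the weights along the output axis, then split
+w_qkv = np.concatenate([w_q, w_k, w_v], axis=1)   # (1, 2304, 768)
+merged = batch_matmul(x, w_qkv)                   # (1, 14, 2304)
+q, k, v = np.split(merged, 3, axis=2)
+
+max_diff = max(float(np.abs(a - b).max()) for a, b in zip(separate, (q, k, v)))
+print(merged.shape, max_diff)
+```
+
+The merged variant computes the same numbers up to floating-point rounding, but launches one larger kernel instead of three smaller ones.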
+
+## Results on the overall BERT model after optimization
+
+Let's define a function combining the optimization passes from above and run it on the entire BERT model.
+We go through the same exercise as above.
+
+We get to 624ms for 100 runs. So yay, we went from 6.5-7ms in PyTorch to ~6.2ms in TVM. This is a 5%-10% speedup. Note that we have only taken a particular, not very large shape. A more serious analysis would consider more problem shapes.
+
+We could probably take it a bit further yet - e.g. fusing the additions after the batch matmul by handling the reshape, but we'll leave it at this for now. Also we will benefit from further improvements to TVM, so it will be interesting to see how the benchmark improves over time. In particular, the upcoming Ansor tuning mechanism seems promising.
+
+## A peek under the hood
+
+### Comparing implementation of models
+
+As you can see, I have always compared PyTorch with TVM outputs to see if they're good.
+Also, when I investigated some inner layer, I grabbed the inputs to that to convert and feed into the TVM model. I do believe that this is a very effective technique.
+
+Sometimes, however, it is difficult to assess whether a deviation between the results is from numerical accuracy or from an error somewhere.
+When I initially converted the model, the `SelfAttention` submodule output was replicated by the TVM model to about 1e-6.
+However, the BertLayer conversion had something like 1e-3. I was not entirely clear whether that might be due to accumulated numerical errors or some material deviation somewhere.
+(This turned out to be the GELU activation, which was converted to FastGELU.)
+
+One of the things I like to do in this case is jump to double precision and check there. Numerical errors should get much smaller, while other deviations would remain of the same order.
+With the PyTorch frontend, you can trace the model converted to float64 on the PyTorch side if you pass `default_dtype="float64"` to the conversion function.
+
+Running the module and comparing to PyTorch should now have 1e-14 or so deviation.
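+
+The double-precision check can be illustrated without TVM. In this sketch, the tanh approximation of GELU is only a stand-in for "some approximate variant" (the exact formula behind FastGELU is an assumption here), and the helper names are hypothetical. Rounding error between float32 and float64 is tiny, while the deviation between the two formulas stays at the same order in both precisions.
+
+```python
+import math
+import numpy as np
+
+def gelu_erf(x):
+    # GELU expressed with the error function, evaluated in the dtype of x
+    t = x.dtype.type
+    scaled = x / np.sqrt(t(2))
+    erf = np.array([math.erf(v) for v in scaled.ravel()], dtype=x.dtype).reshape(x.shape)
+    return t(0.5) * x * (t(1) + erf)
+
+def gelu_tanh(x):
+    # tanh approximation of GELU (stand-in for an approximate variant)
+    t = x.dtype.type
+    c = t(math.sqrt(2.0 / math.pi))
+    return t(0.5) * x * (t(1) + np.tanh(c * (x + t(0.044715) * x ** 3)))
+
+def max_dev(a, b):
+    return float(np.abs(a.astype(np.float64) - b.astype(np.float64)).max())
+
+x64 = np.linspace(-4.0, 4.0, 801)
+x32 = x64.astype(np.float32)
+
+rounding_32 = max_dev(gelu_erf(x32), gelu_erf(x64))    # same formula, different precision
+material_32 = max_dev(gelu_tanh(x32), gelu_erf(x32))   # different formula, float32
+material_64 = max_dev(gelu_tanh(x64), gelu_erf(x64))   # different formula, float64
+print(rounding_32, material_32, material_64)
+```
+
+If switching to float64 does not shrink the deviation, it is not a rounding effect.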
+
+### Improvements in TVM to facilitate this usecase
+
+Before this worked as shown here, we had to close some gaps (but a recent git checkout will include all of them):
+- The TVM PyTorch converter did not support inputs other than fp32. We [implemented improved conversion](https://github.com/t-vi/tvm/tree/pytorch_frontend_type_fix), now also included in TVM upstream.
+- The TVM schedule, i.e. the organization of the computation, of the workhorse operation, batch_matmul, was fixed and it was very slow (similar to running without a tuned schedule now). So we [implemented a tuneable schedule](https://github.com/apache/incubator-tvm/pull/5752).
+- The PyTorch converter produces batch matmul operations (it could probably also be changed to produce dense layers instead). But as we saw, one of the larger speed advantages is to combine Query Key and Value linear layers, so we implemented [fusing batch matmul operations](https://github.com/apache/incubator-tvm/pull/5791).
+- When comparing the computation results, we noticed that the [GELU](https://pytorch.org/docs/master/generated/torch.nn.GELU.html) function was converted to its FastGELU variant. We fixed that. (There is a _fast math_ optimization pass in TVM that does some replacement of the error function, though we didn't check if it yields FastGELU for the GELU expressed with the error function.)
+- TVM was initially (and still is to some extent) focussed on static shapes. Recently it experiments with dynamic operations. The dynamic reshape - taking an argument for the target shape - is an early one of these experiments, but as seen above, it prevented the fusion of batch matmuls because the common subexpression elimination pass didn't detect that it could merge the identical input reshaping. This has improved recently.
+
+# Training PyTorch models with TVM computation
+
+In this second part we want to see if we could use TVM while training BERT in PyTorch.
+Of course, this opens an entire new can of worms as we need to deal with autodifferentiation.
+While we stay with the theme from above and take `BertLayer` as the example, our methodology is representative of non-trivial modules in general.
+We will want to divert the computation during training to TVM.
+
+So the user can take a (traceable) module and do
+```
+add_tvm_dispatch(module, sample_input)
+```
+and then if she calls module with inputs of the same shape as the sample_input, she'll get the outputs computed by TVM (as PyTorch tensors, of course) and if not, it'll just use the regular forward.
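+
+The dispatch-or-fall-back behaviour can be sketched in a few lines of plain Python. This is not the implementation from the companion post: `TinyModule`, `add_shape_dispatch` and `fast_fn` are hypothetical names, NumPy arrays stand in for tensors, and `fast_fn` stands in for the TVM-compiled function.
+
+```python
+import numpy as np
+
+class TinyModule:
+    """Stand-in for a module; only `forward` matters here."""
+    def forward(self, x):
+        return x * 2.0
+
+def add_shape_dispatch(module, sample_input, fast_fn):
+    # hypothetical helper: route calls with the sample's shape to fast_fn
+    expected_shape = sample_input.shape
+    regular_forward = module.forward
+
+    def forward(x):
+        if x.shape == expected_shape:
+            return fast_fn(x)          # specialised path
+        return regular_forward(x)      # fall back to the usual forward
+
+    module.forward = forward           # instance attribute shadows the method
+    return module
+
+calls = []
+def fast_double(x):
+    calls.append(x.shape)              # record that the fast path was taken
+    return x * 2.0
+
+m = add_shape_dispatch(TinyModule(), np.zeros((1, 14, 768)), fast_double)
+out_fast = m.forward(np.ones((1, 14, 768)))
+out_fallback = m.forward(np.ones((2, 7)))
+print(calls, out_fallback.shape)
+```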
+
+But we already hinted at the bad news: In this part we will see how to do these things. We will not yet achieve a great speedup.
+
+But enough talk, let us dive right in!
+Again, we get our relay model by running a traced `BertLayer` from the transformer `Bert` model through `tvm.relay.frontend.from_pytorch`.
+
+One thing we'll do in between is to move from a modular interface in PyTorch - with named parameters - to a functional
+interface (which is what TVM can do for us). The first thing we want to do for that is arrange for the function arguments to be in an order that we can work with - i.e. first the direct inputs to the module and then the parameters in the same order that PyTorch uses them. After this operation, our `BertLayer` in TVM looks like this:
+
+![svg](/images/bert-pytorch/pytorch-tvm-training_20_0.svg)
+
+As in the BERT inference, we want to run some optimization passes.
+
+But we also have a few new transformations:
+
+- One particularity of the Autodifferentiation is that it'll use a lot of `..._like` operations to broadcast or "unbroadcast" (summation is the dual of broadcasting w.r.t. autodifferentiation) things. But this means that you now have two tensor arguments, even if the latter doesn't really need a gradient. `ZappLike` replaces those operations with the corresponding functions taking a shape parameter instead.
+- Another thing is the "rooting" of derivatives. TVM generates a tensor of all ones of the same shape as the return values of our function as the starting point for the chain rule. These are then multiplied with the derivatives of our operations. But multiplication with ones is not doing much, so we strike that. Similarly, TVM initializes the gradient of a variable (an input) to zeros of the same shape. If it isn't used, the gradient will be zero, but if it is, the "real gradient" will [...]
+- TVM doesn't have a training variant for the `LayerNorm` (or `BatchNorm` or others). So we implement a pass to spell out the computation.
+- TVM also doesn't have training dropout. Here the problem is somewhat harder to fix, as TVM doesn't currently have random number generation. We instead replace the dropout by a construct taking a random Bernoulli draw (of 0/1 values) and mimicking dropout with that. The idea is that we'll use PyTorch to generate this mask for us. This has the added benefit that (if we generate dropout masks in the same order as PyTorch) we'll get the exact same result.
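+
+To make the last point concrete, here is a minimal sketch of dropout driven by an externally supplied 0/1 mask. The mask is drawn with NumPy here purely for illustration (the post obtains it from PyTorch), and `dropout_with_mask` is a hypothetical name.
+
+```python
+import numpy as np
+
+def dropout_with_mask(x, mask, p):
+    # mask holds 0/1 Bernoulli draws (1 = keep); kept values are rescaled by 1/(1-p)
+    return x * mask / (1.0 - p)
+
+rng = np.random.default_rng(12345)
+p = 0.1
+x = rng.standard_normal((1, 14, 768))
+mask = (rng.random(x.shape) >= p).astype(x.dtype)   # keep with probability 1-p
+y = dropout_with_mask(x, mask, p)
+print(float(mask.mean()))
+```
+
+Because the mask is an ordinary input, the compiled function stays deterministic and the randomness lives entirely outside of it.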
+
+As hinted at above, TVM's gradient taking assumes that it is the last element in the computation (the ones-Tensors discussed above). This isn't a good fit with PyTorch's modular view which expects a `grad_out` for each output to be given. Happily, this is computationally equivalent to multiplying by `grad_out` and summation, so we amend our function with that. We wish to be flexible, so we allow both functions returning a single tensor and those returning a tuple of tensors.
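+
+The equivalence can be checked numerically on a toy function. In this sketch (the function `f` and all shapes are made up for illustration), the vector-Jacobian product with `grad_out` is compared against a finite-difference gradient of the scalar `sum(grad_out * f(x))`.
+
+```python
+import numpy as np
+
+rng = np.random.default_rng(0)
+W = rng.standard_normal((3, 4))
+x = rng.standard_normal(4)
+grad_out = rng.standard_normal(3)
+
+def f(x):
+    return np.tanh(W @ x)
+
+# what PyTorch's backward computes: grad_in = J^T grad_out, with J[i, j] = d f_i / d x_j
+jacobian = (1.0 - f(x) ** 2)[:, None] * W
+vjp = jacobian.T @ grad_out
+
+# gradient of the amended scalar function s(x) = sum(grad_out * f(x))
+def s(x):
+    return float(np.sum(grad_out * f(x)))
+
+eps = 1e-6
+grad_s = np.array([(s(x + eps * e) - s(x - eps * e)) / (2 * eps) for e in np.eye(4)])
+
+max_diff = float(np.abs(vjp - grad_s).max())
+print(max_diff)
+```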
+
+With these modifications applied, our model looks like this:
+
+![svg](/images/bert-pytorch/pytorch-tvm-training_25_0.svg)
+
+Finally we can take the grad. As we get a lot of `let` nodes, we bring it to normal form using the `ToGraphNormalForm` pass.
+TVM's gradient-taking returns a function that has the same parameters as the original function (in our case amended with the `grad_out` and dropout) and then returns a tuple of the original return and a tuple containing gradients for all inputs.
+The first thing we do is to drop all the gradients for `grad_out` and `dropout` which we don't need.
+Then we run our simplification passes.
+
+So this is the graph we have now for forward and backward:
+
+![svg](/images/bert-pytorch/pytorch-tvm-training_31_0.svg)
+
+But in PyTorch, we first compute the forward and then the backwards, so we have to take out the saw and 
+split our graph. One of the difficult problems is what to do with things computed for both forward and backward. It is a hard problem, related to the MinCut problem.
+
+Our extremal options could be:
+- One could only keep the inputs and recompute everything as needed.
+- If we had a scalar output, we could compute the gradient and multiply with the derivative of the later layers on backward. (Loss functions might do that.) This does not, however, work for non-scalar tensor outputs.
+
+We'll do the following: We compute the forward normally, but we keep all things that will be used in the backward. This is too much, unfortunately, and it is very likely the reason we don't see an end to end speedup. We'll discuss some potential heuristics below.
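+
+On a toy layer, "compute the forward and keep what the backward needs" looks as follows. This is a hand-written sketch with hypothetical names, not the graph splitting done on the Relay function: the forward returns its regular output plus a dictionary of captures, and the backward consumes `grad_out` and those captures.
+
+```python
+import numpy as np
+
+def forward_and_capture(x, w):
+    h = x @ w
+    y = np.tanh(h)
+    # keeping y lets the backward avoid recomputing the matmul and tanh
+    captures = {"x": x, "w": w, "y": y}
+    return y, captures
+
+def backward(grad_out, captures):
+    x, w, y = captures["x"], captures["w"], captures["y"]
+    grad_h = grad_out * (1.0 - y ** 2)   # derivative of tanh
+    grad_x = grad_h @ w.T
+    grad_w = x.T @ grad_h
+    return grad_x, grad_w
+
+rng = np.random.default_rng(0)
+x = rng.standard_normal((14, 8))
+w = rng.standard_normal((8, 8))
+y, captures = forward_and_capture(x, w)
+grad_out = rng.standard_normal(y.shape)
+grad_x, grad_w = backward(grad_out, captures)
+print(grad_x.shape, grad_w.shape)
+```
+
+Capturing only `x` and `w` and recomputing `y` in the backward would be the other extreme; which intermediates to keep is exactly the trade-off discussed above.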
+
+We use a coloring here. First we color all nodes of the forward computation in red. Then we traverse the gradient calculation and color the forward nodes that the backward needs in blue. This gives us a chance to show off the attribute support in our visualization.
+
+A bit of (PyTorch) terminology: When we have a function *Layer : x ↦ y* followed by some *Loss: y ↦ l ∈ ℝ*, the backward is *BackwardOfLayer : grad`_`out ↦ grad`_`in* with *grad`_`out = dl/dy* and *grad`_`in = dl/dx*.
+
+![svg](/images/bert-pytorch/pytorch-tvm-training_34_0.svg)
+
+In order to split the function as described above, we collect the blue nodes as the ones to capture - but constants will
+just be duplicated and inputs (`Var` nodes) need to be treated separately.
+Now we can split out the backward, replacing all the blue nodes with variables.
+
+Next we take the forward and amend it to also return the required intermediates. The forward then looks like this:
+
+![svg](/images/bert-pytorch/pytorch-tvm-training_40_0.svg)
+
+TVM cannot return nested tuples, so we flatten the output in the function. Again we differentiate between tensor-valued functions and tuple valued ones (i.e. those returning potentially multiple tensors).
+
+And at last, we can let TVM do its magic and compile our functions, say to `gr_only_compiled_module`
+and `fw_and_cap_compiled_module`.
+Time to give it a spin. We define convenience functions to move tensors between PyTorch and TVM and get the model parameters as a TVM dictionary.
+
+
+```python
+def tensor_to_tvm(t):
+    return tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(t))
+def tensor_from_tvm(a):
+    return(torch.utils.dlpack.from_dlpack(a.to_dlpack()))
+
+model_params_tvm = {k: tensor_to_tvm(v) for k, v in pytorch_model.state_dict().items()}
+```
+
+Similarly, we get the inputs on the GPU in PyTorch and TVM.
+
+We need to deal with the dropout. It will turn out that our record of the three dropout random draws happens in the same order as the dropout in the model. We did a depth-first search on the computational graph to find them and if the values of the dropout are connected in the graph rather than being on independent branches, this will be the order in which PyTorch draws the matrices, too. If not, good luck fiddling with the order.
+
+```python
+torch.manual_seed(12345)
+drop_c = {}
+for k in dropout_info.keys(): # we don't know the order
+    p, typ = dropout_info[k]
+    drop_c[k] = torch.nn.functional.dropout(torch.ones([int(i) for i in typ.shape], 
+                                              dtype=getattr(torch, typ.dtype), device="cuda"), p=p)*(1-p)
+
+drop_tvm = {n: tensor_to_tvm(t) for n, t in drop_c.items()}
+```
+
+Now we can run the forward.
+
+```python
+fw_and_cap_compiled_module.set_input('input', inp_tvm[0])
+fw_and_cap_compiled_module.set_input('attention_mask', inp_tvm[1])
+fw_and_cap_compiled_module.set_input(**model_params_tvm)
+fw_and_cap_compiled_module.set_input(**drop_tvm)
+fw_and_cap_compiled_module.run()
+```
+
+And we can compare the output to PyTorch's:
+
+```python
+torch.manual_seed(12345)
+pytorch_model.train()
+res = pytorch_model(*inp_c)[0]
+numpy.abs(fw_and_cap_compiled_module.get_output(0).asnumpy()-res.detach().cpu().numpy()).max()
+```
+
+This gives `2.1457672e-06`.
+
+Supergood. Let's also try the backward. We generate a `grad_out`, set all the variables and run the backward model:
+
+
+```python
+gr_out_c = torch.randn(res.shape, device="cuda", dtype=res.dtype)
+```
+
+```python
+num_captures = len(capture_vars)
+num_regular_outputs = len(fw_and_cap_fn_flattened.body.fields) - num_captures
+captured_values = {v.name_hint: fw_and_cap_compiled_module.get_output(num_regular_outputs + i) for i, v in enumerate(capture_vars)}
+
+gr_only_compiled_module.set_input(**drop_tvm)
+gr_only_compiled_module.set_input(**model_params_tvm)
+gr_only_compiled_module.set_input(**captured_values)
+gr_only_compiled_module.set_input('gr:out:0', tensor_to_tvm(gr_out_c))
+gr_only_compiled_module.run()
+```
+
+On the PyTorch side, it is easiest to re-run the forward (remembering to reset the random seed) and get the grads.
+
+
+```python
+torch.manual_seed(12345)
+pytorch_model.train()
+inp_c_rq = [i.requires_grad_() for i in inp_c]
+for p in pytorch_model.parameters():
+    p.requires_grad_()
+res = pytorch_model(*inp_c_rq)[0]
+grads_pt = torch.autograd.grad(res, inp_c_rq + list(pytorch_model.parameters()), gr_out_c, allow_unused=True)
+
+```
+
+Did it work? It seems so:
+
+
+```python
+for i, g_pt in enumerate(grads_pt):
+    print(numpy.abs(gr_only_compiled_module.get_output(i).asnumpy() - g_pt.cpu().numpy()).max())
+```
+
+gives us a list of numbers in the 1e-5ish range.
+
+But we wanted to get something running in PyTorch, right?
+
+Keeping with how PyTorch works, we first define an `autograd.Function` that does the things we just did manually:
+ 
+In the `forward`:
+
+- Generate the dropout random values,
+- Run the forward,
+- Record the captures, inputs, and dropout values needed for backward.
+
+In the `backward`, run the backward and return the result (as PyTorch tensors).
+
+With that, we get a PyTorch autograd.Function calling into TVM (we would want a small wrapper for that).
+
+Now all we need to do to achieve our goal of getting a method `add_tvm_dispatch(module, sample_inputs)` is
+to trace the module, create the TVM-based autograd function from it, and then replace the forward with one that calls
+that function (with the parameters) if applicable or falls back to the usual forward.
+Python's unlimited dynamism makes that kind of hackery relatively easy.
+As all this is not really TVM-related, we spare ourselves that here (but you can check the
+[companion post](https://lernapparat.de/transformers-pytorch-tvm/)).
+
+## Performance
+
+As I said in the beginning, we aren't quite where we want to eventually be in terms of performance.
+After tuning the tasks (and on the not very realistic inference example from the HuggingFace BERT + PyTorch JIT tutorial)
+we run 100 iterations of the TVM-enabled BertLayer forward and backward similar to how we did it for the inference.
+One iteration takes 6.2ms going through TVM versus 1.3ms on PyTorch.
+
+So we ran our model through TVM all right. But it's not as fast as the usual method yet. Here is to opportunity!
+
+More seriously, we have two immediate paths to improve performance:
+
+- Find a better set of captured nodes.
+- Find optimizations on the TVM graph.
+
+In terms of heuristics for the former (remember that it is quite likely NP-hard, i.e. I believe it is, but I didn't work out a formal proof),
+one would want to re-do cheap computation, most prominently point-wise computation (or maybe anything but matmul?). But that is for another day.
+
+I hope you enjoyed the tutorial. I look forward to your comments at <tv...@lernapparat.de>.
+
+# Acknowledgements
+
+I had many interesting discussions with HuggingFace people and Morgan Funtowicz in particular. Also, the TVM contributors had many good comments during the review of the patches to TVM and on the forums. The creation of this tutorial was sponsored by AMD.
+
+# Author
+
+[Thomas Viehmann](https://lernapparat.de/) is the founder of [MathInf GmbH](https://mathinf.eu/), Munich, Germany, a boutique training and consultancy firm focusing on Machine Learning and PyTorch.
+He is a PyTorch core developer and co-authored [Deep Learning with PyTorch](https://www.manning.com/books/deep-learning-with-pytorch), which is currently available as a [free download from the PyTorch website](https://pytorch.org/deep-learning-with-pytorch).
diff --git a/images/bert-pytorch/bert-tvm_49_0.svg b/images/bert-pytorch/bert-tvm_49_0.svg
new file mode 100644
index 0000000..35b0aee
--- /dev/null
+++ b/images/bert-pytorch/bert-tvm_49_0.svg
@@ -0,0 +1,691 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
+ "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
+<!-- Generated by graphviz version 2.43.0 (0)
+ -->
+<!-- Title: %3 Pages: 1 -->
+<svg width="2140pt" height="1916pt"
+ viewBox="0.00 0.00 2140.22 1916.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
+<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 1912)">
+<title>%3</title>
+<polygon fill="white" stroke="transparent" points="-4,4 -4,-1912 2136.22,-1912 2136.22,4 -4,4"/>
+<!-- 0 -->
+<g id="node1" class="node">
+<title>0</title>
+<ellipse fill="none" stroke="black" cx="1238.18" cy="-1746" rx="170.87" ry="18"/>
+<text text-anchor="middle" x="1238.18" y="-1742.3" font-family="Times,serif" font-size="14.00">input: Tensor[(1, 14, 768), float32]</text>
+</g>
+<!-- 16 -->
+<g id="node9" class="node">
+<title>16</title>
+<polygon fill="none" stroke="black" points="1044.18,-1692 692.18,-1692 692.18,-1656 1044.18,-1656 1044.18,-1692"/>
+<text text-anchor="middle" x="868.18" y="-1670.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 0&#45;&gt;16 -->
+<g id="edge1" class="edge">
+<title>0&#45;&gt;16</title>
+<path fill="none" stroke="black" d="M1158.81,-1729.98C1102.99,-1719.42 1027.88,-1705.21 968,-1693.89"/>
+<polygon fill="black" stroke="black" points="968.63,-1690.44 958.15,-1692.02 967.33,-1697.32 968.63,-1690.44"/>
+</g>
+<!-- 26 -->
+<g id="node19" class="node">
+<title>26</title>
+<polygon fill="none" stroke="black" points="1414.18,-1692 1062.18,-1692 1062.18,-1656 1414.18,-1656 1414.18,-1692"/>
+<text text-anchor="middle" x="1238.18" y="-1670.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 0&#45;&gt;26 -->
+<g id="edge13" class="edge">
+<title>0&#45;&gt;26</title>
+<path fill="none" stroke="black" d="M1238.18,-1727.7C1238.18,-1719.98 1238.18,-1710.71 1238.18,-1702.11"/>
+<polygon fill="black" stroke="black" points="1241.68,-1702.1 1238.18,-1692.1 1234.68,-1702.1 1241.68,-1702.1"/>
+</g>
+<!-- 47 -->
+<g id="node39" class="node">
+<title>47</title>
+<polygon fill="none" stroke="black" points="1784.18,-1692 1432.18,-1692 1432.18,-1656 1784.18,-1656 1784.18,-1692"/>
+<text text-anchor="middle" x="1608.18" y="-1670.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 0&#45;&gt;47 -->
+<g id="edge37" class="edge">
+<title>0&#45;&gt;47</title>
+<path fill="none" stroke="black" d="M1317.56,-1729.98C1373.37,-1719.42 1448.48,-1705.21 1508.36,-1693.89"/>
+<polygon fill="black" stroke="black" points="1509.04,-1697.32 1518.21,-1692.02 1507.74,-1690.44 1509.04,-1697.32"/>
+</g>
+<!-- 1 -->
+<g id="node2" class="node">
+<title>1</title>
+<ellipse fill="none" stroke="black" cx="200.18" cy="-1890" rx="200.36" ry="18"/>
+<text text-anchor="middle" x="200.18" y="-1886.3" font-family="Times,serif" font-size="14.00">query.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 17 -->
+<g id="node10" class="node">
+<title>17</title>
+<polygon fill="none" stroke="black" points="310.68,-1836 117.68,-1836 117.68,-1800 310.68,-1800 310.68,-1836"/>
+<text text-anchor="middle" x="214.18" y="-1814.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 1&#45;&gt;17 -->
+<g id="edge2" class="edge">
+<title>1&#45;&gt;17</title>
+<path fill="none" stroke="black" d="M203.64,-1871.7C205.19,-1863.98 207.04,-1854.71 208.76,-1846.11"/>
+<polygon fill="black" stroke="black" points="212.23,-1846.6 210.76,-1836.1 205.37,-1845.22 212.23,-1846.6"/>
+</g>
+<!-- 2 -->
+<g id="node3" class="node">
+<title>2</title>
+<ellipse fill="none" stroke="black" cx="184.18" cy="-1530" rx="167.07" ry="18"/>
+<text text-anchor="middle" x="184.18" y="-1526.3" font-family="Times,serif" font-size="14.00">query.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 22 -->
+<g id="node15" class="node">
+<title>22</title>
+<polygon fill="none" stroke="black" points="578.68,-1476 507.68,-1476 507.68,-1440 578.68,-1440 578.68,-1476"/>
+<text text-anchor="middle" x="543.18" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 2&#45;&gt;22 -->
+<g id="edge9" class="edge">
+<title>2&#45;&gt;22</title>
+<path fill="none" stroke="black" d="M261.2,-1513.98C333.06,-1499.97 437.89,-1479.53 497.41,-1467.92"/>
+<polygon fill="black" stroke="black" points="498.13,-1471.35 507.27,-1466 496.79,-1464.48 498.13,-1471.35"/>
+</g>
+<!-- 3 -->
+<g id="node4" class="node">
+<title>3</title>
+<ellipse fill="none" stroke="black" cx="608.18" cy="-1890" rx="189.57" ry="18"/>
+<text text-anchor="middle" x="608.18" y="-1886.3" font-family="Times,serif" font-size="14.00">key.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 27 -->
+<g id="node20" class="node">
+<title>27</title>
+<polygon fill="none" stroke="black" points="700.68,-1836 507.68,-1836 507.68,-1800 700.68,-1800 700.68,-1836"/>
+<text text-anchor="middle" x="604.18" y="-1814.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 3&#45;&gt;27 -->
+<g id="edge14" class="edge">
+<title>3&#45;&gt;27</title>
+<path fill="none" stroke="black" d="M607.19,-1871.7C606.75,-1863.98 606.22,-1854.71 605.73,-1846.11"/>
+<polygon fill="black" stroke="black" points="609.22,-1845.89 605.16,-1836.1 602.24,-1846.29 609.22,-1845.89"/>
+</g>
+<!-- 4 -->
+<g id="node5" class="node">
+<title>4</title>
+<ellipse fill="none" stroke="black" cx="891.18" cy="-1530" rx="156.77" ry="18"/>
+<text text-anchor="middle" x="891.18" y="-1526.3" font-family="Times,serif" font-size="14.00">key.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 32 -->
+<g id="node25" class="node">
+<title>32</title>
+<polygon fill="none" stroke="black" points="1033.68,-1476 962.68,-1476 962.68,-1440 1033.68,-1440 1033.68,-1476"/>
+<text text-anchor="middle" x="998.18" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 4&#45;&gt;32 -->
+<g id="edge21" class="edge">
+<title>4&#45;&gt;32</title>
+<path fill="none" stroke="black" d="M916.81,-1512.23C930.77,-1503.1 948.23,-1491.68 963.4,-1481.76"/>
+<polygon fill="black" stroke="black" points="965.57,-1484.52 972.02,-1476.12 961.73,-1478.66 965.57,-1484.52"/>
+</g>
+<!-- 5 -->
+<g id="node6" class="node">
+<title>5</title>
+<ellipse fill="none" stroke="black" cx="1350.18" cy="-882" rx="217.96" ry="18"/>
+<text text-anchor="middle" x="1350.18" y="-878.3" font-family="Times,serif" font-size="14.00">attention_mask: Tensor[(1, 1, 1, 14), float32]</text>
+</g>
+<!-- 42 -->
+<g id="node34" class="node">
+<title>42</title>
+<polygon fill="none" stroke="black" points="1385.68,-828 1314.68,-828 1314.68,-792 1385.68,-792 1385.68,-828"/>
+<text text-anchor="middle" x="1350.18" y="-806.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 5&#45;&gt;42 -->
+<g id="edge32" class="edge">
+<title>5&#45;&gt;42</title>
+<path fill="none" stroke="black" d="M1350.18,-863.7C1350.18,-855.98 1350.18,-846.71 1350.18,-838.11"/>
+<polygon fill="black" stroke="black" points="1353.68,-838.1 1350.18,-828.1 1346.68,-838.1 1353.68,-838.1"/>
+</g>
+<!-- 6 -->
+<g id="node7" class="node">
+<title>6</title>
+<ellipse fill="none" stroke="black" cx="1908.18" cy="-1890" rx="200.36" ry="18"/>
+<text text-anchor="middle" x="1908.18" y="-1886.3" font-family="Times,serif" font-size="14.00">value.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 48 -->
+<g id="node40" class="node">
+<title>48</title>
+<polygon fill="none" stroke="black" points="2004.68,-1836 1811.68,-1836 1811.68,-1800 2004.68,-1800 2004.68,-1836"/>
+<text text-anchor="middle" x="1908.18" y="-1814.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 6&#45;&gt;48 -->
+<g id="edge38" class="edge">
+<title>6&#45;&gt;48</title>
+<path fill="none" stroke="black" d="M1908.18,-1871.7C1908.18,-1863.98 1908.18,-1854.71 1908.18,-1846.11"/>
+<polygon fill="black" stroke="black" points="1911.68,-1846.1 1908.18,-1836.1 1904.68,-1846.1 1911.68,-1846.1"/>
+</g>
+<!-- 7 -->
+<g id="node8" class="node">
+<title>7</title>
+<ellipse fill="none" stroke="black" cx="1965.18" cy="-1530" rx="167.07" ry="18"/>
+<text text-anchor="middle" x="1965.18" y="-1526.3" font-family="Times,serif" font-size="14.00">value.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 53 -->
+<g id="node45" class="node">
+<title>53</title>
+<polygon fill="none" stroke="black" points="1641.68,-1476 1570.68,-1476 1570.68,-1440 1641.68,-1440 1641.68,-1476"/>
+<text text-anchor="middle" x="1606.18" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 7&#45;&gt;53 -->
+<g id="edge45" class="edge">
+<title>7&#45;&gt;53</title>
+<path fill="none" stroke="black" d="M1888.17,-1513.98C1816.3,-1499.97 1711.47,-1479.53 1651.95,-1467.92"/>
+<polygon fill="black" stroke="black" points="1652.58,-1464.48 1642.09,-1466 1651.24,-1471.35 1652.58,-1464.48"/>
+</g>
+<!-- 20 -->
+<g id="node13" class="node">
+<title>20</title>
+<polygon fill="none" stroke="black" points="627.68,-1620 458.68,-1620 458.68,-1584 627.68,-1584 627.68,-1620"/>
+<text text-anchor="middle" x="543.18" y="-1598.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 16&#45;&gt;20 -->
+<g id="edge5" class="edge">
+<title>16&#45;&gt;20</title>
+<path fill="none" stroke="black" d="M789.09,-1655.97C741.72,-1645.76 681.37,-1632.76 632.31,-1622.2"/>
+<polygon fill="black" stroke="black" points="632.86,-1618.73 622.34,-1620.05 631.38,-1625.58 632.86,-1618.73"/>
+</g>
+<!-- 18 -->
+<g id="node11" class="node">
+<title>18</title>
+<polygon fill="none" stroke="black" points="401.68,-1764 40.68,-1764 40.68,-1728 401.68,-1728 401.68,-1764"/>
+<text text-anchor="middle" x="221.18" y="-1742.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 17&#45;&gt;18 -->
+<g id="edge3" class="edge">
+<title>17&#45;&gt;18</title>
+<path fill="none" stroke="black" d="M215.91,-1799.7C216.68,-1791.98 217.61,-1782.71 218.47,-1774.11"/>
+<polygon fill="black" stroke="black" points="221.96,-1774.4 219.47,-1764.1 214.99,-1773.71 221.96,-1774.4"/>
+</g>
+<!-- 19 -->
+<g id="node12" class="node">
+<title>19</title>
+<polygon fill="none" stroke="black" points="405.68,-1692 194.68,-1692 194.68,-1656 405.68,-1656 405.68,-1692"/>
+<text text-anchor="middle" x="300.18" y="-1670.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 18&#45;&gt;19 -->
+<g id="edge4" class="edge">
+<title>18&#45;&gt;19</title>
+<path fill="none" stroke="black" d="M240.71,-1727.7C250.56,-1718.97 262.67,-1708.24 273.38,-1698.75"/>
+<polygon fill="black" stroke="black" points="275.72,-1701.36 280.88,-1692.1 271.07,-1696.12 275.72,-1701.36"/>
+</g>
+<!-- 19&#45;&gt;20 -->
+<g id="edge6" class="edge">
+<title>19&#45;&gt;20</title>
+<path fill="none" stroke="black" d="M359.32,-1655.97C393.86,-1646.01 437.63,-1633.41 473.8,-1622.99"/>
+<polygon fill="black" stroke="black" points="475.14,-1626.24 483.78,-1620.11 473.21,-1619.52 475.14,-1626.24"/>
+</g>
+<!-- 21 -->
+<g id="node14" class="node">
+<title>21</title>
+<polygon fill="none" stroke="black" points="716.68,-1548 369.68,-1548 369.68,-1512 716.68,-1512 716.68,-1548"/>
+<text text-anchor="middle" x="543.18" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 20&#45;&gt;21 -->
+<g id="edge7" class="edge">
+<title>20&#45;&gt;21</title>
+<path fill="none" stroke="black" d="M543.18,-1583.7C543.18,-1575.98 543.18,-1566.71 543.18,-1558.11"/>
+<polygon fill="black" stroke="black" points="546.68,-1558.1 543.18,-1548.1 539.68,-1558.1 546.68,-1558.1"/>
+</g>
+<!-- 21&#45;&gt;22 -->
+<g id="edge8" class="edge">
+<title>21&#45;&gt;22</title>
+<path fill="none" stroke="black" d="M543.18,-1511.7C543.18,-1503.98 543.18,-1494.71 543.18,-1486.11"/>
+<polygon fill="black" stroke="black" points="546.68,-1486.1 543.18,-1476.1 539.68,-1486.1 546.68,-1486.1"/>
+</g>
+<!-- 23 -->
+<g id="node16" class="node">
+<title>23</title>
+<polygon fill="none" stroke="black" points="797.68,-1404 432.68,-1404 432.68,-1368 797.68,-1368 797.68,-1404"/>
+<text text-anchor="middle" x="615.18" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 22&#45;&gt;23 -->
+<g id="edge10" class="edge">
+<title>22&#45;&gt;23</title>
+<path fill="none" stroke="black" d="M560.98,-1439.7C569.79,-1431.14 580.56,-1420.66 590.18,-1411.3"/>
+<polygon fill="black" stroke="black" points="592.86,-1413.58 597.59,-1404.1 587.98,-1408.57 592.86,-1413.58"/>
+</g>
+<!-- 24 -->
+<g id="node17" class="node">
+<title>24</title>
+<polygon fill="none" stroke="black" points="740.68,-1332 511.68,-1332 511.68,-1296 740.68,-1296 740.68,-1332"/>
+<text text-anchor="middle" x="626.18" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 23&#45;&gt;24 -->
+<g id="edge11" class="edge">
+<title>23&#45;&gt;24</title>
+<path fill="none" stroke="black" d="M617.9,-1367.7C619.11,-1359.98 620.57,-1350.71 621.92,-1342.11"/>
+<polygon fill="black" stroke="black" points="625.4,-1342.53 623.49,-1332.1 618.48,-1341.44 625.4,-1342.53"/>
+</g>
+<!-- 25 -->
+<g id="node18" class="node">
+<title>25</title>
+<polygon fill="none" stroke="black" points="863.68,-1116 520.68,-1116 520.68,-1080 863.68,-1080 863.68,-1116"/>
+<text text-anchor="middle" x="692.18" y="-1094.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 64], reverse=0)</text>
+</g>
+<!-- 24&#45;&gt;25 -->
+<g id="edge12" class="edge">
+<title>24&#45;&gt;25</title>
+<path fill="none" stroke="black" d="M631.47,-1295.85C642.94,-1258.68 670.15,-1170.44 683.91,-1125.82"/>
+<polygon fill="black" stroke="black" points="687.26,-1126.82 686.87,-1116.23 680.58,-1124.76 687.26,-1126.82"/>
+</g>
+<!-- 38 -->
+<g id="node31" class="node">
+<title>38</title>
+<polygon fill="none" stroke="black" points="1082.68,-1044 913.68,-1044 913.68,-1008 1082.68,-1008 1082.68,-1044"/>
+<text text-anchor="middle" x="998.18" y="-1022.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 25&#45;&gt;38 -->
+<g id="edge27" class="edge">
+<title>25&#45;&gt;38</title>
+<path fill="none" stroke="black" d="M766.65,-1079.97C811.06,-1069.8 867.6,-1056.87 913.68,-1046.33"/>
+<polygon fill="black" stroke="black" points="914.68,-1049.69 923.65,-1044.05 913.12,-1042.87 914.68,-1049.69"/>
+</g>
+<!-- 30 -->
+<g id="node23" class="node">
+<title>30</title>
+<polygon fill="none" stroke="black" points="1322.68,-1620 1153.68,-1620 1153.68,-1584 1322.68,-1584 1322.68,-1620"/>
+<text text-anchor="middle" x="1238.18" y="-1598.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 26&#45;&gt;30 -->
+<g id="edge17" class="edge">
+<title>26&#45;&gt;30</title>
+<path fill="none" stroke="black" d="M1238.18,-1655.7C1238.18,-1647.98 1238.18,-1638.71 1238.18,-1630.11"/>
+<polygon fill="black" stroke="black" points="1241.68,-1630.1 1238.18,-1620.1 1234.68,-1630.1 1241.68,-1630.1"/>
+</g>
+<!-- 28 -->
+<g id="node21" class="node">
+<title>28</title>
+<polygon fill="none" stroke="black" points="782.68,-1764 421.68,-1764 421.68,-1728 782.68,-1728 782.68,-1764"/>
+<text text-anchor="middle" x="602.18" y="-1742.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 27&#45;&gt;28 -->
+<g id="edge15" class="edge">
+<title>27&#45;&gt;28</title>
+<path fill="none" stroke="black" d="M603.69,-1799.7C603.47,-1791.98 603.2,-1782.71 602.96,-1774.11"/>
+<polygon fill="black" stroke="black" points="606.46,-1774 602.67,-1764.1 599.46,-1774.2 606.46,-1774"/>
+</g>
+<!-- 29 -->
+<g id="node22" class="node">
+<title>29</title>
+<polygon fill="none" stroke="black" points="673.68,-1692 462.68,-1692 462.68,-1656 673.68,-1656 673.68,-1692"/>
+<text text-anchor="middle" x="568.18" y="-1670.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 28&#45;&gt;29 -->
+<g id="edge16" class="edge">
+<title>28&#45;&gt;29</title>
+<path fill="none" stroke="black" d="M593.78,-1727.7C589.91,-1719.73 585.23,-1710.1 580.94,-1701.26"/>
+<polygon fill="black" stroke="black" points="584.01,-1699.57 576.49,-1692.1 577.71,-1702.63 584.01,-1699.57"/>
+</g>
+<!-- 29&#45;&gt;30 -->
+<g id="edge18" class="edge">
+<title>29&#45;&gt;30</title>
+<path fill="none" stroke="black" d="M673.75,-1657.18C676.93,-1656.77 680.08,-1656.38 683.18,-1656 844.38,-1636.4 1032.65,-1619.73 1143.28,-1610.56"/>
+<polygon fill="black" stroke="black" points="1143.73,-1614.04 1153.41,-1609.73 1143.16,-1607.06 1143.73,-1614.04"/>
+</g>
+<!-- 31 -->
+<g id="node24" class="node">
+<title>31</title>
+<polygon fill="none" stroke="black" points="1412.68,-1548 1065.68,-1548 1065.68,-1512 1412.68,-1512 1412.68,-1548"/>
+<text text-anchor="middle" x="1239.18" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 30&#45;&gt;31 -->
+<g id="edge19" class="edge">
+<title>30&#45;&gt;31</title>
+<path fill="none" stroke="black" d="M1238.43,-1583.7C1238.54,-1575.98 1238.67,-1566.71 1238.79,-1558.11"/>
+<polygon fill="black" stroke="black" points="1242.29,-1558.15 1238.94,-1548.1 1235.3,-1558.05 1242.29,-1558.15"/>
+</g>
+<!-- 31&#45;&gt;32 -->
+<g id="edge20" class="edge">
+<title>31&#45;&gt;32</title>
+<path fill="none" stroke="black" d="M1180.54,-1511.97C1138.46,-1499.74 1082.57,-1483.51 1043.81,-1472.25"/>
+<polygon fill="black" stroke="black" points="1044.43,-1468.79 1033.85,-1469.36 1042.48,-1475.51 1044.43,-1468.79"/>
+</g>
+<!-- 33 -->
+<g id="node26" class="node">
+<title>33</title>
+<polygon fill="none" stroke="black" points="1180.68,-1404 815.68,-1404 815.68,-1368 1180.68,-1368 1180.68,-1404"/>
+<text text-anchor="middle" x="998.18" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 32&#45;&gt;33 -->
+<g id="edge22" class="edge">
+<title>32&#45;&gt;33</title>
+<path fill="none" stroke="black" d="M998.18,-1439.7C998.18,-1431.98 998.18,-1422.71 998.18,-1414.11"/>
+<polygon fill="black" stroke="black" points="1001.68,-1414.1 998.18,-1404.1 994.68,-1414.1 1001.68,-1414.1"/>
+</g>
+<!-- 34 -->
+<g id="node27" class="node">
+<title>34</title>
+<polygon fill="none" stroke="black" points="1112.68,-1332 883.68,-1332 883.68,-1296 1112.68,-1296 1112.68,-1332"/>
+<text text-anchor="middle" x="998.18" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 33&#45;&gt;34 -->
+<g id="edge23" class="edge">
+<title>33&#45;&gt;34</title>
+<path fill="none" stroke="black" d="M998.18,-1367.7C998.18,-1359.98 998.18,-1350.71 998.18,-1342.11"/>
+<polygon fill="black" stroke="black" points="1001.68,-1342.1 998.18,-1332.1 994.68,-1342.1 1001.68,-1342.1"/>
+</g>
+<!-- 35 -->
+<g id="node28" class="node">
+<title>35</title>
+<polygon fill="none" stroke="black" points="1112.68,-1260 883.68,-1260 883.68,-1224 1112.68,-1224 1112.68,-1260"/>
+<text text-anchor="middle" x="998.18" y="-1238.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 1, 3, 2])</text>
+</g>
+<!-- 34&#45;&gt;35 -->
+<g id="edge24" class="edge">
+<title>34&#45;&gt;35</title>
+<path fill="none" stroke="black" d="M998.18,-1295.7C998.18,-1287.98 998.18,-1278.71 998.18,-1270.11"/>
+<polygon fill="black" stroke="black" points="1001.68,-1270.1 998.18,-1260.1 994.68,-1270.1 1001.68,-1270.1"/>
+</g>
+<!-- 36 -->
+<g id="node29" class="node">
+<title>36</title>
+<polygon fill="none" stroke="black" points="1169.68,-1188 826.68,-1188 826.68,-1152 1169.68,-1152 1169.68,-1188"/>
+<text text-anchor="middle" x="998.18" y="-1166.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 64, 14], reverse=0)</text>
+</g>
+<!-- 35&#45;&gt;36 -->
+<g id="edge25" class="edge">
+<title>35&#45;&gt;36</title>
+<path fill="none" stroke="black" d="M998.18,-1223.7C998.18,-1215.98 998.18,-1206.71 998.18,-1198.11"/>
+<polygon fill="black" stroke="black" points="1001.68,-1198.1 998.18,-1188.1 994.68,-1198.1 1001.68,-1198.1"/>
+</g>
+<!-- 37 -->
+<g id="node30" class="node">
+<title>37</title>
+<polygon fill="none" stroke="black" points="1103.68,-1116 892.68,-1116 892.68,-1080 1103.68,-1080 1103.68,-1116"/>
+<text text-anchor="middle" x="998.18" y="-1094.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 36&#45;&gt;37 -->
+<g id="edge26" class="edge">
+<title>36&#45;&gt;37</title>
+<path fill="none" stroke="black" d="M998.18,-1151.7C998.18,-1143.98 998.18,-1134.71 998.18,-1126.11"/>
+<polygon fill="black" stroke="black" points="1001.68,-1126.1 998.18,-1116.1 994.68,-1126.1 1001.68,-1126.1"/>
+</g>
+<!-- 37&#45;&gt;38 -->
+<g id="edge28" class="edge">
+<title>37&#45;&gt;38</title>
+<path fill="none" stroke="black" d="M998.18,-1079.7C998.18,-1071.98 998.18,-1062.71 998.18,-1054.11"/>
+<polygon fill="black" stroke="black" points="1001.68,-1054.1 998.18,-1044.1 994.68,-1054.1 1001.68,-1054.1"/>
+</g>
+<!-- 39 -->
+<g id="node32" class="node">
+<title>39</title>
+<polygon fill="none" stroke="black" points="1216.68,-972 851.68,-972 851.68,-936 1216.68,-936 1216.68,-972"/>
+<text text-anchor="middle" x="1034.18" y="-950.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 12, 14, 14], reverse=0)</text>
+</g>
+<!-- 38&#45;&gt;39 -->
+<g id="edge29" class="edge">
+<title>38&#45;&gt;39</title>
+<path fill="none" stroke="black" d="M1007.08,-1007.7C1011.18,-999.73 1016.13,-990.1 1020.67,-981.26"/>
+<polygon fill="black" stroke="black" points="1023.92,-982.6 1025.39,-972.1 1017.7,-979.4 1023.92,-982.6"/>
+</g>
+<!-- 41 -->
+<g id="node33" class="node">
+<title>41</title>
+<polygon fill="none" stroke="black" points="1105.68,-900 998.68,-900 998.68,-864 1105.68,-864 1105.68,-900"/>
+<text text-anchor="middle" x="1052.18" y="-878.3" font-family="Times,serif" font-size="14.00">divide(·, 8.0)</text>
+</g>
+<!-- 39&#45;&gt;41 -->
+<g id="edge30" class="edge">
+<title>39&#45;&gt;41</title>
+<path fill="none" stroke="black" d="M1038.63,-935.7C1040.64,-927.9 1043.05,-918.51 1045.28,-909.83"/>
+<polygon fill="black" stroke="black" points="1048.68,-910.66 1047.78,-900.1 1041.9,-908.92 1048.68,-910.66"/>
+</g>
+<!-- 41&#45;&gt;42 -->
+<g id="edge31" class="edge">
+<title>41&#45;&gt;42</title>
+<path fill="none" stroke="black" d="M1106.06,-868.07C1111.84,-866.69 1117.63,-865.31 1123.18,-864 1185.98,-849.15 1258.61,-832.24 1304.46,-821.6"/>
+<polygon fill="black" stroke="black" points="1305.44,-824.96 1314.39,-819.29 1303.86,-818.14 1305.44,-824.96"/>
+</g>
+<!-- 43 -->
+<g id="node35" class="node">
+<title>43</title>
+<polygon fill="none" stroke="black" points="1445.18,-756 1271.18,-756 1271.18,-720 1445.18,-720 1445.18,-756"/>
+<text text-anchor="middle" x="1358.18" y="-734.3" font-family="Times,serif" font-size="14.00">nn.softmax(·| axis=&#45;1)</text>
+</g>
+<!-- 42&#45;&gt;43 -->
+<g id="edge33" class="edge">
+<title>42&#45;&gt;43</title>
+<path fill="none" stroke="black" d="M1352.16,-791.7C1353.04,-783.98 1354.1,-774.71 1355.08,-766.11"/>
+<polygon fill="black" stroke="black" points="1358.57,-766.44 1356.23,-756.1 1351.61,-765.64 1358.57,-766.44"/>
+</g>
+<!-- 44 -->
+<g id="node36" class="node">
+<title>44</title>
+<polygon fill="none" stroke="black" points="1472.68,-684 1289.68,-684 1289.68,-648 1472.68,-648 1472.68,-684"/>
+<text text-anchor="middle" x="1381.18" y="-662.3" font-family="Times,serif" font-size="14.00">nn.dropout(·| rate=0.1)</text>
+</g>
+<!-- 43&#45;&gt;44 -->
+<g id="edge34" class="edge">
+<title>43&#45;&gt;44</title>
+<path fill="none" stroke="black" d="M1363.87,-719.7C1366.43,-711.9 1369.52,-702.51 1372.37,-693.83"/>
+<polygon fill="black" stroke="black" points="1375.77,-694.7 1375.56,-684.1 1369.12,-692.51 1375.77,-694.7"/>
+</g>
+<!-- 45 -->
+<g id="node37" class="node">
+<title>45</title>
+<polygon fill="none" stroke="black" points="1477.18,-612 1309.18,-612 1309.18,-576 1477.18,-576 1477.18,-612"/>
+<text text-anchor="middle" x="1393.18" y="-590.3" font-family="Times,serif" font-size="14.00">TupleGetItem(idx=0)</text>
+</g>
+<!-- 44&#45;&gt;45 -->
+<g id="edge35" class="edge">
+<title>44&#45;&gt;45</title>
+<path fill="none" stroke="black" d="M1384.15,-647.7C1385.47,-639.98 1387.06,-630.71 1388.53,-622.11"/>
+<polygon fill="black" stroke="black" points="1392.01,-622.55 1390.25,-612.1 1385.11,-621.37 1392.01,-622.55"/>
+</g>
+<!-- 46 -->
+<g id="node38" class="node">
+<title>46</title>
+<polygon fill="none" stroke="black" points="1570.68,-540 1227.68,-540 1227.68,-504 1570.68,-504 1570.68,-540"/>
+<text text-anchor="middle" x="1399.18" y="-518.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 14], reverse=0)</text>
+</g>
+<!-- 45&#45;&gt;46 -->
+<g id="edge36" class="edge">
+<title>45&#45;&gt;46</title>
+<path fill="none" stroke="black" d="M1394.67,-575.7C1395.33,-567.98 1396.12,-558.71 1396.86,-550.11"/>
+<polygon fill="black" stroke="black" points="1400.35,-550.37 1397.72,-540.1 1393.37,-549.77 1400.35,-550.37"/>
+</g>
+<!-- 58 -->
+<g id="node50" class="node">
+<title>58</title>
+<polygon fill="none" stroke="black" points="1582.68,-468 1413.68,-468 1413.68,-432 1582.68,-432 1582.68,-468"/>
+<text text-anchor="middle" x="1498.18" y="-446.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 46&#45;&gt;58 -->
+<g id="edge50" class="edge">
+<title>46&#45;&gt;58</title>
+<path fill="none" stroke="black" d="M1423.4,-503.88C1436.1,-494.89 1451.85,-483.76 1465.61,-474.03"/>
+<polygon fill="black" stroke="black" points="1467.73,-476.82 1473.87,-468.19 1463.69,-471.11 1467.73,-476.82"/>
+</g>
+<!-- 51 -->
+<g id="node43" class="node">
+<title>51</title>
+<polygon fill="none" stroke="black" points="1692.68,-1620 1523.68,-1620 1523.68,-1584 1692.68,-1584 1692.68,-1620"/>
+<text text-anchor="middle" x="1608.18" y="-1598.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 47&#45;&gt;51 -->
+<g id="edge41" class="edge">
+<title>47&#45;&gt;51</title>
+<path fill="none" stroke="black" d="M1608.18,-1655.7C1608.18,-1647.98 1608.18,-1638.71 1608.18,-1630.11"/>
+<polygon fill="black" stroke="black" points="1611.68,-1630.1 1608.18,-1620.1 1604.68,-1630.1 1611.68,-1630.1"/>
+</g>
+<!-- 49 -->
+<g id="node41" class="node">
+<title>49</title>
+<polygon fill="none" stroke="black" points="2088.68,-1764 1727.68,-1764 1727.68,-1728 2088.68,-1728 2088.68,-1764"/>
+<text text-anchor="middle" x="1908.18" y="-1742.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 48&#45;&gt;49 -->
+<g id="edge39" class="edge">
+<title>48&#45;&gt;49</title>
+<path fill="none" stroke="black" d="M1908.18,-1799.7C1908.18,-1791.98 1908.18,-1782.71 1908.18,-1774.11"/>
+<polygon fill="black" stroke="black" points="1911.68,-1774.1 1908.18,-1764.1 1904.68,-1774.1 1911.68,-1774.1"/>
+</g>
+<!-- 50 -->
+<g id="node42" class="node">
+<title>50</title>
+<polygon fill="none" stroke="black" points="2013.68,-1692 1802.68,-1692 1802.68,-1656 2013.68,-1656 2013.68,-1692"/>
+<text text-anchor="middle" x="1908.18" y="-1670.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 49&#45;&gt;50 -->
+<g id="edge40" class="edge">
+<title>49&#45;&gt;50</title>
+<path fill="none" stroke="black" d="M1908.18,-1727.7C1908.18,-1719.98 1908.18,-1710.71 1908.18,-1702.11"/>
+<polygon fill="black" stroke="black" points="1911.68,-1702.1 1908.18,-1692.1 1904.68,-1702.1 1911.68,-1702.1"/>
+</g>
+<!-- 50&#45;&gt;51 -->
+<g id="edge42" class="edge">
+<title>50&#45;&gt;51</title>
+<path fill="none" stroke="black" d="M1835.18,-1655.97C1791.63,-1645.8 1736.21,-1632.87 1691.03,-1622.33"/>
+<polygon fill="black" stroke="black" points="1691.79,-1618.91 1681.25,-1620.05 1690.2,-1625.73 1691.79,-1618.91"/>
+</g>
+<!-- 52 -->
+<g id="node44" class="node">
+<title>52</title>
+<polygon fill="none" stroke="black" points="1779.68,-1548 1432.68,-1548 1432.68,-1512 1779.68,-1512 1779.68,-1548"/>
+<text text-anchor="middle" x="1606.18" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 51&#45;&gt;52 -->
+<g id="edge43" class="edge">
+<title>51&#45;&gt;52</title>
+<path fill="none" stroke="black" d="M1607.69,-1583.7C1607.47,-1575.98 1607.2,-1566.71 1606.96,-1558.11"/>
+<polygon fill="black" stroke="black" points="1610.46,-1558 1606.67,-1548.1 1603.46,-1558.2 1610.46,-1558"/>
+</g>
+<!-- 52&#45;&gt;53 -->
+<g id="edge44" class="edge">
+<title>52&#45;&gt;53</title>
+<path fill="none" stroke="black" d="M1606.18,-1511.7C1606.18,-1503.98 1606.18,-1494.71 1606.18,-1486.11"/>
+<polygon fill="black" stroke="black" points="1609.68,-1486.1 1606.18,-1476.1 1602.68,-1486.1 1609.68,-1486.1"/>
+</g>
+<!-- 54 -->
+<g id="node46" class="node">
+<title>54</title>
+<polygon fill="none" stroke="black" points="1787.68,-1404 1422.68,-1404 1422.68,-1368 1787.68,-1368 1787.68,-1404"/>
+<text text-anchor="middle" x="1605.18" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 53&#45;&gt;54 -->
+<g id="edge46" class="edge">
+<title>53&#45;&gt;54</title>
+<path fill="none" stroke="black" d="M1605.94,-1439.7C1605.82,-1431.98 1605.69,-1422.71 1605.57,-1414.11"/>
+<polygon fill="black" stroke="black" points="1609.07,-1414.05 1605.43,-1404.1 1602.07,-1414.15 1609.07,-1414.05"/>
+</g>
+<!-- 55 -->
+<g id="node47" class="node">
+<title>55</title>
+<polygon fill="none" stroke="black" points="1719.68,-1332 1490.68,-1332 1490.68,-1296 1719.68,-1296 1719.68,-1332"/>
+<text text-anchor="middle" x="1605.18" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 54&#45;&gt;55 -->
+<g id="edge47" class="edge">
+<title>54&#45;&gt;55</title>
+<path fill="none" stroke="black" d="M1605.18,-1367.7C1605.18,-1359.98 1605.18,-1350.71 1605.18,-1342.11"/>
+<polygon fill="black" stroke="black" points="1608.68,-1342.1 1605.18,-1332.1 1601.68,-1342.1 1608.68,-1342.1"/>
+</g>
+<!-- 56 -->
+<g id="node48" class="node">
+<title>56</title>
+<polygon fill="none" stroke="black" points="1775.68,-1260 1432.68,-1260 1432.68,-1224 1775.68,-1224 1775.68,-1260"/>
+<text text-anchor="middle" x="1604.18" y="-1238.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 64], reverse=0)</text>
+</g>
+<!-- 55&#45;&gt;56 -->
+<g id="edge48" class="edge">
+<title>55&#45;&gt;56</title>
+<path fill="none" stroke="black" d="M1604.94,-1295.7C1604.82,-1287.98 1604.69,-1278.71 1604.57,-1270.11"/>
+<polygon fill="black" stroke="black" points="1608.07,-1270.05 1604.43,-1260.1 1601.07,-1270.15 1608.07,-1270.05"/>
+</g>
+<!-- 57 -->
+<g id="node49" class="node">
+<title>57</title>
+<polygon fill="none" stroke="black" points="1706.68,-1044 1495.68,-1044 1495.68,-1008 1706.68,-1008 1706.68,-1044"/>
+<text text-anchor="middle" x="1601.18" y="-1022.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 56&#45;&gt;57 -->
+<g id="edge49" class="edge">
+<title>56&#45;&gt;57</title>
+<path fill="none" stroke="black" d="M1603.94,-1223.85C1603.42,-1186.83 1602.19,-1099.18 1601.57,-1054.39"/>
+<polygon fill="black" stroke="black" points="1605.06,-1054.18 1601.42,-1044.23 1598.06,-1054.28 1605.06,-1054.18"/>
+</g>
+<!-- 57&#45;&gt;58 -->
+<g id="edge51" class="edge">
+<title>57&#45;&gt;58</title>
+<path fill="none" stroke="black" d="M1601.18,-1007.95C1601.18,-981.29 1601.18,-928.11 1601.18,-883 1601.18,-883 1601.18,-883 1601.18,-593 1601.18,-552.36 1603.61,-537.21 1580.18,-504 1571.45,-491.62 1558.82,-481.42 1546.08,-473.4"/>
+<polygon fill="black" stroke="black" points="1547.67,-470.27 1537.29,-468.19 1544.1,-476.3 1547.67,-470.27"/>
+</g>
+<!-- 59 -->
+<g id="node51" class="node">
+<title>59</title>
+<polygon fill="none" stroke="black" points="1680.68,-396 1315.68,-396 1315.68,-360 1680.68,-360 1680.68,-396"/>
+<text text-anchor="middle" x="1498.18" y="-374.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 12, 14, 64], reverse=0)</text>
+</g>
+<!-- 58&#45;&gt;59 -->
+<g id="edge52" class="edge">
+<title>58&#45;&gt;59</title>
+<path fill="none" stroke="black" d="M1498.18,-431.7C1498.18,-423.98 1498.18,-414.71 1498.18,-406.11"/>
+<polygon fill="black" stroke="black" points="1501.68,-406.1 1498.18,-396.1 1494.68,-406.1 1501.68,-406.1"/>
+</g>
+<!-- 60 -->
+<g id="node52" class="node">
+<title>60</title>
+<polygon fill="none" stroke="black" points="1612.68,-324 1383.68,-324 1383.68,-288 1612.68,-288 1612.68,-324"/>
+<text text-anchor="middle" x="1498.18" y="-302.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 59&#45;&gt;60 -->
+<g id="edge53" class="edge">
+<title>59&#45;&gt;60</title>
+<path fill="none" stroke="black" d="M1498.18,-359.7C1498.18,-351.98 1498.18,-342.71 1498.18,-334.11"/>
+<polygon fill="black" stroke="black" points="1501.68,-334.1 1498.18,-324.1 1494.68,-334.1 1501.68,-334.1"/>
+</g>
+<!-- 61 -->
+<g id="node53" class="node">
+<title>61</title>
+<polygon fill="none" stroke="black" points="1530.68,-252 1465.68,-252 1465.68,-216 1530.68,-216 1530.68,-252"/>
+<text text-anchor="middle" x="1498.18" y="-230.3" font-family="Times,serif" font-size="14.00">copy(·)</text>
+</g>
+<!-- 60&#45;&gt;61 -->
+<g id="edge54" class="edge">
+<title>60&#45;&gt;61</title>
+<path fill="none" stroke="black" d="M1498.18,-287.7C1498.18,-279.98 1498.18,-270.71 1498.18,-262.11"/>
+<polygon fill="black" stroke="black" points="1501.68,-262.1 1498.18,-252.1 1494.68,-262.1 1501.68,-262.1"/>
+</g>
+<!-- 62 -->
+<g id="node54" class="node">
+<title>62</title>
+<polygon fill="none" stroke="black" points="1671.68,-180 1324.68,-180 1324.68,-144 1671.68,-144 1671.68,-180"/>
+<text text-anchor="middle" x="1498.18" y="-158.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 61&#45;&gt;62 -->
+<g id="edge55" class="edge">
+<title>61&#45;&gt;62</title>
+<path fill="none" stroke="black" d="M1498.18,-215.7C1498.18,-207.98 1498.18,-198.71 1498.18,-190.11"/>
+<polygon fill="black" stroke="black" points="1501.68,-190.1 1498.18,-180.1 1494.68,-190.1 1501.68,-190.1"/>
+</g>
+<!-- 63 -->
+<g id="node55" class="node">
+<title>63</title>
+<polygon fill="none" stroke="black" points="1541.18,-108 1455.18,-108 1455.18,-72 1541.18,-72 1541.18,-108"/>
+<text text-anchor="middle" x="1498.18" y="-86.3" font-family="Times,serif" font-size="14.00">Tuple[...])</text>
+</g>
+<!-- 62&#45;&gt;63 -->
+<g id="edge56" class="edge">
+<title>62&#45;&gt;63</title>
+<path fill="none" stroke="black" d="M1498.18,-143.7C1498.18,-135.98 1498.18,-126.71 1498.18,-118.11"/>
+<polygon fill="black" stroke="black" points="1501.68,-118.1 1498.18,-108.1 1494.68,-118.1 1501.68,-118.1"/>
+</g>
+<!-- 64 -->
+<g id="node56" class="node">
+<title>64</title>
+<polygon fill="none" stroke="black" points="1538.18,-36 1458.18,-36 1458.18,0 1538.18,0 1538.18,-36"/>
+<text text-anchor="middle" x="1498.18" y="-14.3" font-family="Times,serif" font-size="14.00">Function</text>
+</g>
+<!-- 63&#45;&gt;64 -->
+<g id="edge57" class="edge">
+<title>63&#45;&gt;64</title>
+<path fill="none" stroke="black" d="M1498.18,-71.7C1498.18,-63.98 1498.18,-54.71 1498.18,-46.11"/>
+<polygon fill="black" stroke="black" points="1501.68,-46.1 1498.18,-36.1 1494.68,-46.1 1501.68,-46.1"/>
+</g>
+</g>
+</svg>
diff --git a/images/bert-pytorch/bert-tvm_54_0.svg b/images/bert-pytorch/bert-tvm_54_0.svg
new file mode 100644
index 0000000..35b0aee
--- /dev/null
+++ b/images/bert-pytorch/bert-tvm_54_0.svg
@@ -0,0 +1,691 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
+ "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
+<!-- Generated by graphviz version 2.43.0 (0)
+ -->
+<!-- Title: %3 Pages: 1 -->
+<svg width="2140pt" height="1916pt"
+ viewBox="0.00 0.00 2140.22 1916.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
+<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 1912)">
+<title>%3</title>
+<polygon fill="white" stroke="transparent" points="-4,4 -4,-1912 2136.22,-1912 2136.22,4 -4,4"/>
+<!-- 0 -->
+<g id="node1" class="node">
+<title>0</title>
+<ellipse fill="none" stroke="black" cx="1238.18" cy="-1746" rx="170.87" ry="18"/>
+<text text-anchor="middle" x="1238.18" y="-1742.3" font-family="Times,serif" font-size="14.00">input: Tensor[(1, 14, 768), float32]</text>
+</g>
+<!-- 16 -->
+<g id="node9" class="node">
+<title>16</title>
+<polygon fill="none" stroke="black" points="1044.18,-1692 692.18,-1692 692.18,-1656 1044.18,-1656 1044.18,-1692"/>
+<text text-anchor="middle" x="868.18" y="-1670.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 0&#45;&gt;16 -->
+<g id="edge1" class="edge">
+<title>0&#45;&gt;16</title>
+<path fill="none" stroke="black" d="M1158.81,-1729.98C1102.99,-1719.42 1027.88,-1705.21 968,-1693.89"/>
+<polygon fill="black" stroke="black" points="968.63,-1690.44 958.15,-1692.02 967.33,-1697.32 968.63,-1690.44"/>
+</g>
+<!-- 26 -->
+<g id="node19" class="node">
+<title>26</title>
+<polygon fill="none" stroke="black" points="1414.18,-1692 1062.18,-1692 1062.18,-1656 1414.18,-1656 1414.18,-1692"/>
+<text text-anchor="middle" x="1238.18" y="-1670.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 0&#45;&gt;26 -->
+<g id="edge13" class="edge">
+<title>0&#45;&gt;26</title>
+<path fill="none" stroke="black" d="M1238.18,-1727.7C1238.18,-1719.98 1238.18,-1710.71 1238.18,-1702.11"/>
+<polygon fill="black" stroke="black" points="1241.68,-1702.1 1238.18,-1692.1 1234.68,-1702.1 1241.68,-1702.1"/>
+</g>
+<!-- 47 -->
+<g id="node39" class="node">
+<title>47</title>
+<polygon fill="none" stroke="black" points="1784.18,-1692 1432.18,-1692 1432.18,-1656 1784.18,-1656 1784.18,-1692"/>
+<text text-anchor="middle" x="1608.18" y="-1670.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 0&#45;&gt;47 -->
+<g id="edge37" class="edge">
+<title>0&#45;&gt;47</title>
+<path fill="none" stroke="black" d="M1317.56,-1729.98C1373.37,-1719.42 1448.48,-1705.21 1508.36,-1693.89"/>
+<polygon fill="black" stroke="black" points="1509.04,-1697.32 1518.21,-1692.02 1507.74,-1690.44 1509.04,-1697.32"/>
+</g>
+<!-- 1 -->
+<g id="node2" class="node">
+<title>1</title>
+<ellipse fill="none" stroke="black" cx="200.18" cy="-1890" rx="200.36" ry="18"/>
+<text text-anchor="middle" x="200.18" y="-1886.3" font-family="Times,serif" font-size="14.00">query.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 17 -->
+<g id="node10" class="node">
+<title>17</title>
+<polygon fill="none" stroke="black" points="310.68,-1836 117.68,-1836 117.68,-1800 310.68,-1800 310.68,-1836"/>
+<text text-anchor="middle" x="214.18" y="-1814.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 1&#45;&gt;17 -->
+<g id="edge2" class="edge">
+<title>1&#45;&gt;17</title>
+<path fill="none" stroke="black" d="M203.64,-1871.7C205.19,-1863.98 207.04,-1854.71 208.76,-1846.11"/>
+<polygon fill="black" stroke="black" points="212.23,-1846.6 210.76,-1836.1 205.37,-1845.22 212.23,-1846.6"/>
+</g>
+<!-- 2 -->
+<g id="node3" class="node">
+<title>2</title>
+<ellipse fill="none" stroke="black" cx="184.18" cy="-1530" rx="167.07" ry="18"/>
+<text text-anchor="middle" x="184.18" y="-1526.3" font-family="Times,serif" font-size="14.00">query.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 22 -->
+<g id="node15" class="node">
+<title>22</title>
+<polygon fill="none" stroke="black" points="578.68,-1476 507.68,-1476 507.68,-1440 578.68,-1440 578.68,-1476"/>
+<text text-anchor="middle" x="543.18" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 2&#45;&gt;22 -->
+<g id="edge9" class="edge">
+<title>2&#45;&gt;22</title>
+<path fill="none" stroke="black" d="M261.2,-1513.98C333.06,-1499.97 437.89,-1479.53 497.41,-1467.92"/>
+<polygon fill="black" stroke="black" points="498.13,-1471.35 507.27,-1466 496.79,-1464.48 498.13,-1471.35"/>
+</g>
+<!-- 3 -->
+<g id="node4" class="node">
+<title>3</title>
+<ellipse fill="none" stroke="black" cx="608.18" cy="-1890" rx="189.57" ry="18"/>
+<text text-anchor="middle" x="608.18" y="-1886.3" font-family="Times,serif" font-size="14.00">key.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 27 -->
+<g id="node20" class="node">
+<title>27</title>
+<polygon fill="none" stroke="black" points="700.68,-1836 507.68,-1836 507.68,-1800 700.68,-1800 700.68,-1836"/>
+<text text-anchor="middle" x="604.18" y="-1814.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 3&#45;&gt;27 -->
+<g id="edge14" class="edge">
+<title>3&#45;&gt;27</title>
+<path fill="none" stroke="black" d="M607.19,-1871.7C606.75,-1863.98 606.22,-1854.71 605.73,-1846.11"/>
+<polygon fill="black" stroke="black" points="609.22,-1845.89 605.16,-1836.1 602.24,-1846.29 609.22,-1845.89"/>
+</g>
+<!-- 4 -->
+<g id="node5" class="node">
+<title>4</title>
+<ellipse fill="none" stroke="black" cx="891.18" cy="-1530" rx="156.77" ry="18"/>
+<text text-anchor="middle" x="891.18" y="-1526.3" font-family="Times,serif" font-size="14.00">key.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 32 -->
+<g id="node25" class="node">
+<title>32</title>
+<polygon fill="none" stroke="black" points="1033.68,-1476 962.68,-1476 962.68,-1440 1033.68,-1440 1033.68,-1476"/>
+<text text-anchor="middle" x="998.18" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 4&#45;&gt;32 -->
+<g id="edge21" class="edge">
+<title>4&#45;&gt;32</title>
+<path fill="none" stroke="black" d="M916.81,-1512.23C930.77,-1503.1 948.23,-1491.68 963.4,-1481.76"/>
+<polygon fill="black" stroke="black" points="965.57,-1484.52 972.02,-1476.12 961.73,-1478.66 965.57,-1484.52"/>
+</g>
+<!-- 5 -->
+<g id="node6" class="node">
+<title>5</title>
+<ellipse fill="none" stroke="black" cx="1350.18" cy="-882" rx="217.96" ry="18"/>
+<text text-anchor="middle" x="1350.18" y="-878.3" font-family="Times,serif" font-size="14.00">attention_mask: Tensor[(1, 1, 1, 14), float32]</text>
+</g>
+<!-- 42 -->
+<g id="node34" class="node">
+<title>42</title>
+<polygon fill="none" stroke="black" points="1385.68,-828 1314.68,-828 1314.68,-792 1385.68,-792 1385.68,-828"/>
+<text text-anchor="middle" x="1350.18" y="-806.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 5&#45;&gt;42 -->
+<g id="edge32" class="edge">
+<title>5&#45;&gt;42</title>
+<path fill="none" stroke="black" d="M1350.18,-863.7C1350.18,-855.98 1350.18,-846.71 1350.18,-838.11"/>
+<polygon fill="black" stroke="black" points="1353.68,-838.1 1350.18,-828.1 1346.68,-838.1 1353.68,-838.1"/>
+</g>
+<!-- 6 -->
+<g id="node7" class="node">
+<title>6</title>
+<ellipse fill="none" stroke="black" cx="1908.18" cy="-1890" rx="200.36" ry="18"/>
+<text text-anchor="middle" x="1908.18" y="-1886.3" font-family="Times,serif" font-size="14.00">value.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 48 -->
+<g id="node40" class="node">
+<title>48</title>
+<polygon fill="none" stroke="black" points="2004.68,-1836 1811.68,-1836 1811.68,-1800 2004.68,-1800 2004.68,-1836"/>
+<text text-anchor="middle" x="1908.18" y="-1814.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 6&#45;&gt;48 -->
+<g id="edge38" class="edge">
+<title>6&#45;&gt;48</title>
+<path fill="none" stroke="black" d="M1908.18,-1871.7C1908.18,-1863.98 1908.18,-1854.71 1908.18,-1846.11"/>
+<polygon fill="black" stroke="black" points="1911.68,-1846.1 1908.18,-1836.1 1904.68,-1846.1 1911.68,-1846.1"/>
+</g>
+<!-- 7 -->
+<g id="node8" class="node">
+<title>7</title>
+<ellipse fill="none" stroke="black" cx="1965.18" cy="-1530" rx="167.07" ry="18"/>
+<text text-anchor="middle" x="1965.18" y="-1526.3" font-family="Times,serif" font-size="14.00">value.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 53 -->
+<g id="node45" class="node">
+<title>53</title>
+<polygon fill="none" stroke="black" points="1641.68,-1476 1570.68,-1476 1570.68,-1440 1641.68,-1440 1641.68,-1476"/>
+<text text-anchor="middle" x="1606.18" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 7&#45;&gt;53 -->
+<g id="edge45" class="edge">
+<title>7&#45;&gt;53</title>
+<path fill="none" stroke="black" d="M1888.17,-1513.98C1816.3,-1499.97 1711.47,-1479.53 1651.95,-1467.92"/>
+<polygon fill="black" stroke="black" points="1652.58,-1464.48 1642.09,-1466 1651.24,-1471.35 1652.58,-1464.48"/>
+</g>
+<!-- 20 -->
+<g id="node13" class="node">
+<title>20</title>
+<polygon fill="none" stroke="black" points="627.68,-1620 458.68,-1620 458.68,-1584 627.68,-1584 627.68,-1620"/>
+<text text-anchor="middle" x="543.18" y="-1598.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 16&#45;&gt;20 -->
+<g id="edge5" class="edge">
+<title>16&#45;&gt;20</title>
+<path fill="none" stroke="black" d="M789.09,-1655.97C741.72,-1645.76 681.37,-1632.76 632.31,-1622.2"/>
+<polygon fill="black" stroke="black" points="632.86,-1618.73 622.34,-1620.05 631.38,-1625.58 632.86,-1618.73"/>
+</g>
+<!-- 18 -->
+<g id="node11" class="node">
+<title>18</title>
+<polygon fill="none" stroke="black" points="401.68,-1764 40.68,-1764 40.68,-1728 401.68,-1728 401.68,-1764"/>
+<text text-anchor="middle" x="221.18" y="-1742.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 17&#45;&gt;18 -->
+<g id="edge3" class="edge">
+<title>17&#45;&gt;18</title>
+<path fill="none" stroke="black" d="M215.91,-1799.7C216.68,-1791.98 217.61,-1782.71 218.47,-1774.11"/>
+<polygon fill="black" stroke="black" points="221.96,-1774.4 219.47,-1764.1 214.99,-1773.71 221.96,-1774.4"/>
+</g>
+<!-- 19 -->
+<g id="node12" class="node">
+<title>19</title>
+<polygon fill="none" stroke="black" points="405.68,-1692 194.68,-1692 194.68,-1656 405.68,-1656 405.68,-1692"/>
+<text text-anchor="middle" x="300.18" y="-1670.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 18&#45;&gt;19 -->
+<g id="edge4" class="edge">
+<title>18&#45;&gt;19</title>
+<path fill="none" stroke="black" d="M240.71,-1727.7C250.56,-1718.97 262.67,-1708.24 273.38,-1698.75"/>
+<polygon fill="black" stroke="black" points="275.72,-1701.36 280.88,-1692.1 271.07,-1696.12 275.72,-1701.36"/>
+</g>
+<!-- 19&#45;&gt;20 -->
+<g id="edge6" class="edge">
+<title>19&#45;&gt;20</title>
+<path fill="none" stroke="black" d="M359.32,-1655.97C393.86,-1646.01 437.63,-1633.41 473.8,-1622.99"/>
+<polygon fill="black" stroke="black" points="475.14,-1626.24 483.78,-1620.11 473.21,-1619.52 475.14,-1626.24"/>
+</g>
+<!-- 21 -->
+<g id="node14" class="node">
+<title>21</title>
+<polygon fill="none" stroke="black" points="716.68,-1548 369.68,-1548 369.68,-1512 716.68,-1512 716.68,-1548"/>
+<text text-anchor="middle" x="543.18" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 20&#45;&gt;21 -->
+<g id="edge7" class="edge">
+<title>20&#45;&gt;21</title>
+<path fill="none" stroke="black" d="M543.18,-1583.7C543.18,-1575.98 543.18,-1566.71 543.18,-1558.11"/>
+<polygon fill="black" stroke="black" points="546.68,-1558.1 543.18,-1548.1 539.68,-1558.1 546.68,-1558.1"/>
+</g>
+<!-- 21&#45;&gt;22 -->
+<g id="edge8" class="edge">
+<title>21&#45;&gt;22</title>
+<path fill="none" stroke="black" d="M543.18,-1511.7C543.18,-1503.98 543.18,-1494.71 543.18,-1486.11"/>
+<polygon fill="black" stroke="black" points="546.68,-1486.1 543.18,-1476.1 539.68,-1486.1 546.68,-1486.1"/>
+</g>
+<!-- 23 -->
+<g id="node16" class="node">
+<title>23</title>
+<polygon fill="none" stroke="black" points="797.68,-1404 432.68,-1404 432.68,-1368 797.68,-1368 797.68,-1404"/>
+<text text-anchor="middle" x="615.18" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 22&#45;&gt;23 -->
+<g id="edge10" class="edge">
+<title>22&#45;&gt;23</title>
+<path fill="none" stroke="black" d="M560.98,-1439.7C569.79,-1431.14 580.56,-1420.66 590.18,-1411.3"/>
+<polygon fill="black" stroke="black" points="592.86,-1413.58 597.59,-1404.1 587.98,-1408.57 592.86,-1413.58"/>
+</g>
+<!-- 24 -->
+<g id="node17" class="node">
+<title>24</title>
+<polygon fill="none" stroke="black" points="740.68,-1332 511.68,-1332 511.68,-1296 740.68,-1296 740.68,-1332"/>
+<text text-anchor="middle" x="626.18" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 23&#45;&gt;24 -->
+<g id="edge11" class="edge">
+<title>23&#45;&gt;24</title>
+<path fill="none" stroke="black" d="M617.9,-1367.7C619.11,-1359.98 620.57,-1350.71 621.92,-1342.11"/>
+<polygon fill="black" stroke="black" points="625.4,-1342.53 623.49,-1332.1 618.48,-1341.44 625.4,-1342.53"/>
+</g>
+<!-- 25 -->
+<g id="node18" class="node">
+<title>25</title>
+<polygon fill="none" stroke="black" points="863.68,-1116 520.68,-1116 520.68,-1080 863.68,-1080 863.68,-1116"/>
+<text text-anchor="middle" x="692.18" y="-1094.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 64], reverse=0)</text>
+</g>
+<!-- 24&#45;&gt;25 -->
+<g id="edge12" class="edge">
+<title>24&#45;&gt;25</title>
+<path fill="none" stroke="black" d="M631.47,-1295.85C642.94,-1258.68 670.15,-1170.44 683.91,-1125.82"/>
+<polygon fill="black" stroke="black" points="687.26,-1126.82 686.87,-1116.23 680.58,-1124.76 687.26,-1126.82"/>
+</g>
+<!-- 38 -->
+<g id="node31" class="node">
+<title>38</title>
+<polygon fill="none" stroke="black" points="1082.68,-1044 913.68,-1044 913.68,-1008 1082.68,-1008 1082.68,-1044"/>
+<text text-anchor="middle" x="998.18" y="-1022.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 25&#45;&gt;38 -->
+<g id="edge27" class="edge">
+<title>25&#45;&gt;38</title>
+<path fill="none" stroke="black" d="M766.65,-1079.97C811.06,-1069.8 867.6,-1056.87 913.68,-1046.33"/>
+<polygon fill="black" stroke="black" points="914.68,-1049.69 923.65,-1044.05 913.12,-1042.87 914.68,-1049.69"/>
+</g>
+<!-- 30 -->
+<g id="node23" class="node">
+<title>30</title>
+<polygon fill="none" stroke="black" points="1322.68,-1620 1153.68,-1620 1153.68,-1584 1322.68,-1584 1322.68,-1620"/>
+<text text-anchor="middle" x="1238.18" y="-1598.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 26&#45;&gt;30 -->
+<g id="edge17" class="edge">
+<title>26&#45;&gt;30</title>
+<path fill="none" stroke="black" d="M1238.18,-1655.7C1238.18,-1647.98 1238.18,-1638.71 1238.18,-1630.11"/>
+<polygon fill="black" stroke="black" points="1241.68,-1630.1 1238.18,-1620.1 1234.68,-1630.1 1241.68,-1630.1"/>
+</g>
+<!-- 28 -->
+<g id="node21" class="node">
+<title>28</title>
+<polygon fill="none" stroke="black" points="782.68,-1764 421.68,-1764 421.68,-1728 782.68,-1728 782.68,-1764"/>
+<text text-anchor="middle" x="602.18" y="-1742.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 27&#45;&gt;28 -->
+<g id="edge15" class="edge">
+<title>27&#45;&gt;28</title>
+<path fill="none" stroke="black" d="M603.69,-1799.7C603.47,-1791.98 603.2,-1782.71 602.96,-1774.11"/>
+<polygon fill="black" stroke="black" points="606.46,-1774 602.67,-1764.1 599.46,-1774.2 606.46,-1774"/>
+</g>
+<!-- 29 -->
+<g id="node22" class="node">
+<title>29</title>
+<polygon fill="none" stroke="black" points="673.68,-1692 462.68,-1692 462.68,-1656 673.68,-1656 673.68,-1692"/>
+<text text-anchor="middle" x="568.18" y="-1670.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 28&#45;&gt;29 -->
+<g id="edge16" class="edge">
+<title>28&#45;&gt;29</title>
+<path fill="none" stroke="black" d="M593.78,-1727.7C589.91,-1719.73 585.23,-1710.1 580.94,-1701.26"/>
+<polygon fill="black" stroke="black" points="584.01,-1699.57 576.49,-1692.1 577.71,-1702.63 584.01,-1699.57"/>
+</g>
+<!-- 29&#45;&gt;30 -->
+<g id="edge18" class="edge">
+<title>29&#45;&gt;30</title>
+<path fill="none" stroke="black" d="M673.75,-1657.18C676.93,-1656.77 680.08,-1656.38 683.18,-1656 844.38,-1636.4 1032.65,-1619.73 1143.28,-1610.56"/>
+<polygon fill="black" stroke="black" points="1143.73,-1614.04 1153.41,-1609.73 1143.16,-1607.06 1143.73,-1614.04"/>
+</g>
+<!-- 31 -->
+<g id="node24" class="node">
+<title>31</title>
+<polygon fill="none" stroke="black" points="1412.68,-1548 1065.68,-1548 1065.68,-1512 1412.68,-1512 1412.68,-1548"/>
+<text text-anchor="middle" x="1239.18" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 30&#45;&gt;31 -->
+<g id="edge19" class="edge">
+<title>30&#45;&gt;31</title>
+<path fill="none" stroke="black" d="M1238.43,-1583.7C1238.54,-1575.98 1238.67,-1566.71 1238.79,-1558.11"/>
+<polygon fill="black" stroke="black" points="1242.29,-1558.15 1238.94,-1548.1 1235.3,-1558.05 1242.29,-1558.15"/>
+</g>
+<!-- 31&#45;&gt;32 -->
+<g id="edge20" class="edge">
+<title>31&#45;&gt;32</title>
+<path fill="none" stroke="black" d="M1180.54,-1511.97C1138.46,-1499.74 1082.57,-1483.51 1043.81,-1472.25"/>
+<polygon fill="black" stroke="black" points="1044.43,-1468.79 1033.85,-1469.36 1042.48,-1475.51 1044.43,-1468.79"/>
+</g>
+<!-- 33 -->
+<g id="node26" class="node">
+<title>33</title>
+<polygon fill="none" stroke="black" points="1180.68,-1404 815.68,-1404 815.68,-1368 1180.68,-1368 1180.68,-1404"/>
+<text text-anchor="middle" x="998.18" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 32&#45;&gt;33 -->
+<g id="edge22" class="edge">
+<title>32&#45;&gt;33</title>
+<path fill="none" stroke="black" d="M998.18,-1439.7C998.18,-1431.98 998.18,-1422.71 998.18,-1414.11"/>
+<polygon fill="black" stroke="black" points="1001.68,-1414.1 998.18,-1404.1 994.68,-1414.1 1001.68,-1414.1"/>
+</g>
+<!-- 34 -->
+<g id="node27" class="node">
+<title>34</title>
+<polygon fill="none" stroke="black" points="1112.68,-1332 883.68,-1332 883.68,-1296 1112.68,-1296 1112.68,-1332"/>
+<text text-anchor="middle" x="998.18" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 33&#45;&gt;34 -->
+<g id="edge23" class="edge">
+<title>33&#45;&gt;34</title>
+<path fill="none" stroke="black" d="M998.18,-1367.7C998.18,-1359.98 998.18,-1350.71 998.18,-1342.11"/>
+<polygon fill="black" stroke="black" points="1001.68,-1342.1 998.18,-1332.1 994.68,-1342.1 1001.68,-1342.1"/>
+</g>
+<!-- 35 -->
+<g id="node28" class="node">
+<title>35</title>
+<polygon fill="none" stroke="black" points="1112.68,-1260 883.68,-1260 883.68,-1224 1112.68,-1224 1112.68,-1260"/>
+<text text-anchor="middle" x="998.18" y="-1238.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 1, 3, 2])</text>
+</g>
+<!-- 34&#45;&gt;35 -->
+<g id="edge24" class="edge">
+<title>34&#45;&gt;35</title>
+<path fill="none" stroke="black" d="M998.18,-1295.7C998.18,-1287.98 998.18,-1278.71 998.18,-1270.11"/>
+<polygon fill="black" stroke="black" points="1001.68,-1270.1 998.18,-1260.1 994.68,-1270.1 1001.68,-1270.1"/>
+</g>
+<!-- 36 -->
+<g id="node29" class="node">
+<title>36</title>
+<polygon fill="none" stroke="black" points="1169.68,-1188 826.68,-1188 826.68,-1152 1169.68,-1152 1169.68,-1188"/>
+<text text-anchor="middle" x="998.18" y="-1166.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 64, 14], reverse=0)</text>
+</g>
+<!-- 35&#45;&gt;36 -->
+<g id="edge25" class="edge">
+<title>35&#45;&gt;36</title>
+<path fill="none" stroke="black" d="M998.18,-1223.7C998.18,-1215.98 998.18,-1206.71 998.18,-1198.11"/>
+<polygon fill="black" stroke="black" points="1001.68,-1198.1 998.18,-1188.1 994.68,-1198.1 1001.68,-1198.1"/>
+</g>
+<!-- 37 -->
+<g id="node30" class="node">
+<title>37</title>
+<polygon fill="none" stroke="black" points="1103.68,-1116 892.68,-1116 892.68,-1080 1103.68,-1080 1103.68,-1116"/>
+<text text-anchor="middle" x="998.18" y="-1094.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 36&#45;&gt;37 -->
+<g id="edge26" class="edge">
+<title>36&#45;&gt;37</title>
+<path fill="none" stroke="black" d="M998.18,-1151.7C998.18,-1143.98 998.18,-1134.71 998.18,-1126.11"/>
+<polygon fill="black" stroke="black" points="1001.68,-1126.1 998.18,-1116.1 994.68,-1126.1 1001.68,-1126.1"/>
+</g>
+<!-- 37&#45;&gt;38 -->
+<g id="edge28" class="edge">
+<title>37&#45;&gt;38</title>
+<path fill="none" stroke="black" d="M998.18,-1079.7C998.18,-1071.98 998.18,-1062.71 998.18,-1054.11"/>
+<polygon fill="black" stroke="black" points="1001.68,-1054.1 998.18,-1044.1 994.68,-1054.1 1001.68,-1054.1"/>
+</g>
+<!-- 39 -->
+<g id="node32" class="node">
+<title>39</title>
+<polygon fill="none" stroke="black" points="1216.68,-972 851.68,-972 851.68,-936 1216.68,-936 1216.68,-972"/>
+<text text-anchor="middle" x="1034.18" y="-950.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 12, 14, 14], reverse=0)</text>
+</g>
+<!-- 38&#45;&gt;39 -->
+<g id="edge29" class="edge">
+<title>38&#45;&gt;39</title>
+<path fill="none" stroke="black" d="M1007.08,-1007.7C1011.18,-999.73 1016.13,-990.1 1020.67,-981.26"/>
+<polygon fill="black" stroke="black" points="1023.92,-982.6 1025.39,-972.1 1017.7,-979.4 1023.92,-982.6"/>
+</g>
+<!-- 41 -->
+<g id="node33" class="node">
+<title>41</title>
+<polygon fill="none" stroke="black" points="1105.68,-900 998.68,-900 998.68,-864 1105.68,-864 1105.68,-900"/>
+<text text-anchor="middle" x="1052.18" y="-878.3" font-family="Times,serif" font-size="14.00">divide(·, 8.0)</text>
+</g>
+<!-- 39&#45;&gt;41 -->
+<g id="edge30" class="edge">
+<title>39&#45;&gt;41</title>
+<path fill="none" stroke="black" d="M1038.63,-935.7C1040.64,-927.9 1043.05,-918.51 1045.28,-909.83"/>
+<polygon fill="black" stroke="black" points="1048.68,-910.66 1047.78,-900.1 1041.9,-908.92 1048.68,-910.66"/>
+</g>
+<!-- 41&#45;&gt;42 -->
+<g id="edge31" class="edge">
+<title>41&#45;&gt;42</title>
+<path fill="none" stroke="black" d="M1106.06,-868.07C1111.84,-866.69 1117.63,-865.31 1123.18,-864 1185.98,-849.15 1258.61,-832.24 1304.46,-821.6"/>
+<polygon fill="black" stroke="black" points="1305.44,-824.96 1314.39,-819.29 1303.86,-818.14 1305.44,-824.96"/>
+</g>
+<!-- 43 -->
+<g id="node35" class="node">
+<title>43</title>
+<polygon fill="none" stroke="black" points="1445.18,-756 1271.18,-756 1271.18,-720 1445.18,-720 1445.18,-756"/>
+<text text-anchor="middle" x="1358.18" y="-734.3" font-family="Times,serif" font-size="14.00">nn.softmax(·| axis=&#45;1)</text>
+</g>
+<!-- 42&#45;&gt;43 -->
+<g id="edge33" class="edge">
+<title>42&#45;&gt;43</title>
+<path fill="none" stroke="black" d="M1352.16,-791.7C1353.04,-783.98 1354.1,-774.71 1355.08,-766.11"/>
+<polygon fill="black" stroke="black" points="1358.57,-766.44 1356.23,-756.1 1351.61,-765.64 1358.57,-766.44"/>
+</g>
+<!-- 44 -->
+<g id="node36" class="node">
+<title>44</title>
+<polygon fill="none" stroke="black" points="1472.68,-684 1289.68,-684 1289.68,-648 1472.68,-648 1472.68,-684"/>
+<text text-anchor="middle" x="1381.18" y="-662.3" font-family="Times,serif" font-size="14.00">nn.dropout(·| rate=0.1)</text>
+</g>
+<!-- 43&#45;&gt;44 -->
+<g id="edge34" class="edge">
+<title>43&#45;&gt;44</title>
+<path fill="none" stroke="black" d="M1363.87,-719.7C1366.43,-711.9 1369.52,-702.51 1372.37,-693.83"/>
+<polygon fill="black" stroke="black" points="1375.77,-694.7 1375.56,-684.1 1369.12,-692.51 1375.77,-694.7"/>
+</g>
+<!-- 45 -->
+<g id="node37" class="node">
+<title>45</title>
+<polygon fill="none" stroke="black" points="1477.18,-612 1309.18,-612 1309.18,-576 1477.18,-576 1477.18,-612"/>
+<text text-anchor="middle" x="1393.18" y="-590.3" font-family="Times,serif" font-size="14.00">TupleGetItem(idx=0)</text>
+</g>
+<!-- 44&#45;&gt;45 -->
+<g id="edge35" class="edge">
+<title>44&#45;&gt;45</title>
+<path fill="none" stroke="black" d="M1384.15,-647.7C1385.47,-639.98 1387.06,-630.71 1388.53,-622.11"/>
+<polygon fill="black" stroke="black" points="1392.01,-622.55 1390.25,-612.1 1385.11,-621.37 1392.01,-622.55"/>
+</g>
+<!-- 46 -->
+<g id="node38" class="node">
+<title>46</title>
+<polygon fill="none" stroke="black" points="1570.68,-540 1227.68,-540 1227.68,-504 1570.68,-504 1570.68,-540"/>
+<text text-anchor="middle" x="1399.18" y="-518.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 14], reverse=0)</text>
+</g>
+<!-- 45&#45;&gt;46 -->
+<g id="edge36" class="edge">
+<title>45&#45;&gt;46</title>
+<path fill="none" stroke="black" d="M1394.67,-575.7C1395.33,-567.98 1396.12,-558.71 1396.86,-550.11"/>
+<polygon fill="black" stroke="black" points="1400.35,-550.37 1397.72,-540.1 1393.37,-549.77 1400.35,-550.37"/>
+</g>
+<!-- 58 -->
+<g id="node50" class="node">
+<title>58</title>
+<polygon fill="none" stroke="black" points="1582.68,-468 1413.68,-468 1413.68,-432 1582.68,-432 1582.68,-468"/>
+<text text-anchor="middle" x="1498.18" y="-446.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 46&#45;&gt;58 -->
+<g id="edge50" class="edge">
+<title>46&#45;&gt;58</title>
+<path fill="none" stroke="black" d="M1423.4,-503.88C1436.1,-494.89 1451.85,-483.76 1465.61,-474.03"/>
+<polygon fill="black" stroke="black" points="1467.73,-476.82 1473.87,-468.19 1463.69,-471.11 1467.73,-476.82"/>
+</g>
+<!-- 51 -->
+<g id="node43" class="node">
+<title>51</title>
+<polygon fill="none" stroke="black" points="1692.68,-1620 1523.68,-1620 1523.68,-1584 1692.68,-1584 1692.68,-1620"/>
+<text text-anchor="middle" x="1608.18" y="-1598.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 47&#45;&gt;51 -->
+<g id="edge41" class="edge">
+<title>47&#45;&gt;51</title>
+<path fill="none" stroke="black" d="M1608.18,-1655.7C1608.18,-1647.98 1608.18,-1638.71 1608.18,-1630.11"/>
+<polygon fill="black" stroke="black" points="1611.68,-1630.1 1608.18,-1620.1 1604.68,-1630.1 1611.68,-1630.1"/>
+</g>
+<!-- 49 -->
+<g id="node41" class="node">
+<title>49</title>
+<polygon fill="none" stroke="black" points="2088.68,-1764 1727.68,-1764 1727.68,-1728 2088.68,-1728 2088.68,-1764"/>
+<text text-anchor="middle" x="1908.18" y="-1742.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 48&#45;&gt;49 -->
+<g id="edge39" class="edge">
+<title>48&#45;&gt;49</title>
+<path fill="none" stroke="black" d="M1908.18,-1799.7C1908.18,-1791.98 1908.18,-1782.71 1908.18,-1774.11"/>
+<polygon fill="black" stroke="black" points="1911.68,-1774.1 1908.18,-1764.1 1904.68,-1774.1 1911.68,-1774.1"/>
+</g>
+<!-- 50 -->
+<g id="node42" class="node">
+<title>50</title>
+<polygon fill="none" stroke="black" points="2013.68,-1692 1802.68,-1692 1802.68,-1656 2013.68,-1656 2013.68,-1692"/>
+<text text-anchor="middle" x="1908.18" y="-1670.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 49&#45;&gt;50 -->
+<g id="edge40" class="edge">
+<title>49&#45;&gt;50</title>
+<path fill="none" stroke="black" d="M1908.18,-1727.7C1908.18,-1719.98 1908.18,-1710.71 1908.18,-1702.11"/>
+<polygon fill="black" stroke="black" points="1911.68,-1702.1 1908.18,-1692.1 1904.68,-1702.1 1911.68,-1702.1"/>
+</g>
+<!-- 50&#45;&gt;51 -->
+<g id="edge42" class="edge">
+<title>50&#45;&gt;51</title>
+<path fill="none" stroke="black" d="M1835.18,-1655.97C1791.63,-1645.8 1736.21,-1632.87 1691.03,-1622.33"/>
+<polygon fill="black" stroke="black" points="1691.79,-1618.91 1681.25,-1620.05 1690.2,-1625.73 1691.79,-1618.91"/>
+</g>
+<!-- 52 -->
+<g id="node44" class="node">
+<title>52</title>
+<polygon fill="none" stroke="black" points="1779.68,-1548 1432.68,-1548 1432.68,-1512 1779.68,-1512 1779.68,-1548"/>
+<text text-anchor="middle" x="1606.18" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 51&#45;&gt;52 -->
+<g id="edge43" class="edge">
+<title>51&#45;&gt;52</title>
+<path fill="none" stroke="black" d="M1607.69,-1583.7C1607.47,-1575.98 1607.2,-1566.71 1606.96,-1558.11"/>
+<polygon fill="black" stroke="black" points="1610.46,-1558 1606.67,-1548.1 1603.46,-1558.2 1610.46,-1558"/>
+</g>
+<!-- 52&#45;&gt;53 -->
+<g id="edge44" class="edge">
+<title>52&#45;&gt;53</title>
+<path fill="none" stroke="black" d="M1606.18,-1511.7C1606.18,-1503.98 1606.18,-1494.71 1606.18,-1486.11"/>
+<polygon fill="black" stroke="black" points="1609.68,-1486.1 1606.18,-1476.1 1602.68,-1486.1 1609.68,-1486.1"/>
+</g>
+<!-- 54 -->
+<g id="node46" class="node">
+<title>54</title>
+<polygon fill="none" stroke="black" points="1787.68,-1404 1422.68,-1404 1422.68,-1368 1787.68,-1368 1787.68,-1404"/>
+<text text-anchor="middle" x="1605.18" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 53&#45;&gt;54 -->
+<g id="edge46" class="edge">
+<title>53&#45;&gt;54</title>
+<path fill="none" stroke="black" d="M1605.94,-1439.7C1605.82,-1431.98 1605.69,-1422.71 1605.57,-1414.11"/>
+<polygon fill="black" stroke="black" points="1609.07,-1414.05 1605.43,-1404.1 1602.07,-1414.15 1609.07,-1414.05"/>
+</g>
+<!-- 55 -->
+<g id="node47" class="node">
+<title>55</title>
+<polygon fill="none" stroke="black" points="1719.68,-1332 1490.68,-1332 1490.68,-1296 1719.68,-1296 1719.68,-1332"/>
+<text text-anchor="middle" x="1605.18" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 54&#45;&gt;55 -->
+<g id="edge47" class="edge">
+<title>54&#45;&gt;55</title>
+<path fill="none" stroke="black" d="M1605.18,-1367.7C1605.18,-1359.98 1605.18,-1350.71 1605.18,-1342.11"/>
+<polygon fill="black" stroke="black" points="1608.68,-1342.1 1605.18,-1332.1 1601.68,-1342.1 1608.68,-1342.1"/>
+</g>
+<!-- 56 -->
+<g id="node48" class="node">
+<title>56</title>
+<polygon fill="none" stroke="black" points="1775.68,-1260 1432.68,-1260 1432.68,-1224 1775.68,-1224 1775.68,-1260"/>
+<text text-anchor="middle" x="1604.18" y="-1238.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 64], reverse=0)</text>
+</g>
+<!-- 55&#45;&gt;56 -->
+<g id="edge48" class="edge">
+<title>55&#45;&gt;56</title>
+<path fill="none" stroke="black" d="M1604.94,-1295.7C1604.82,-1287.98 1604.69,-1278.71 1604.57,-1270.11"/>
+<polygon fill="black" stroke="black" points="1608.07,-1270.05 1604.43,-1260.1 1601.07,-1270.15 1608.07,-1270.05"/>
+</g>
+<!-- 57 -->
+<g id="node49" class="node">
+<title>57</title>
+<polygon fill="none" stroke="black" points="1706.68,-1044 1495.68,-1044 1495.68,-1008 1706.68,-1008 1706.68,-1044"/>
+<text text-anchor="middle" x="1601.18" y="-1022.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 56&#45;&gt;57 -->
+<g id="edge49" class="edge">
+<title>56&#45;&gt;57</title>
+<path fill="none" stroke="black" d="M1603.94,-1223.85C1603.42,-1186.83 1602.19,-1099.18 1601.57,-1054.39"/>
+<polygon fill="black" stroke="black" points="1605.06,-1054.18 1601.42,-1044.23 1598.06,-1054.28 1605.06,-1054.18"/>
+</g>
+<!-- 57&#45;&gt;58 -->
+<g id="edge51" class="edge">
+<title>57&#45;&gt;58</title>
+<path fill="none" stroke="black" d="M1601.18,-1007.95C1601.18,-981.29 1601.18,-928.11 1601.18,-883 1601.18,-883 1601.18,-883 1601.18,-593 1601.18,-552.36 1603.61,-537.21 1580.18,-504 1571.45,-491.62 1558.82,-481.42 1546.08,-473.4"/>
+<polygon fill="black" stroke="black" points="1547.67,-470.27 1537.29,-468.19 1544.1,-476.3 1547.67,-470.27"/>
+</g>
+<!-- 59 -->
+<g id="node51" class="node">
+<title>59</title>
+<polygon fill="none" stroke="black" points="1680.68,-396 1315.68,-396 1315.68,-360 1680.68,-360 1680.68,-396"/>
+<text text-anchor="middle" x="1498.18" y="-374.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 12, 14, 64], reverse=0)</text>
+</g>
+<!-- 58&#45;&gt;59 -->
+<g id="edge52" class="edge">
+<title>58&#45;&gt;59</title>
+<path fill="none" stroke="black" d="M1498.18,-431.7C1498.18,-423.98 1498.18,-414.71 1498.18,-406.11"/>
+<polygon fill="black" stroke="black" points="1501.68,-406.1 1498.18,-396.1 1494.68,-406.1 1501.68,-406.1"/>
+</g>
+<!-- 60 -->
+<g id="node52" class="node">
+<title>60</title>
+<polygon fill="none" stroke="black" points="1612.68,-324 1383.68,-324 1383.68,-288 1612.68,-288 1612.68,-324"/>
+<text text-anchor="middle" x="1498.18" y="-302.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 59&#45;&gt;60 -->
+<g id="edge53" class="edge">
+<title>59&#45;&gt;60</title>
+<path fill="none" stroke="black" d="M1498.18,-359.7C1498.18,-351.98 1498.18,-342.71 1498.18,-334.11"/>
+<polygon fill="black" stroke="black" points="1501.68,-334.1 1498.18,-324.1 1494.68,-334.1 1501.68,-334.1"/>
+</g>
+<!-- 61 -->
+<g id="node53" class="node">
+<title>61</title>
+<polygon fill="none" stroke="black" points="1530.68,-252 1465.68,-252 1465.68,-216 1530.68,-216 1530.68,-252"/>
+<text text-anchor="middle" x="1498.18" y="-230.3" font-family="Times,serif" font-size="14.00">copy(·)</text>
+</g>
+<!-- 60&#45;&gt;61 -->
+<g id="edge54" class="edge">
+<title>60&#45;&gt;61</title>
+<path fill="none" stroke="black" d="M1498.18,-287.7C1498.18,-279.98 1498.18,-270.71 1498.18,-262.11"/>
+<polygon fill="black" stroke="black" points="1501.68,-262.1 1498.18,-252.1 1494.68,-262.1 1501.68,-262.1"/>
+</g>
+<!-- 62 -->
+<g id="node54" class="node">
+<title>62</title>
+<polygon fill="none" stroke="black" points="1671.68,-180 1324.68,-180 1324.68,-144 1671.68,-144 1671.68,-180"/>
+<text text-anchor="middle" x="1498.18" y="-158.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 61&#45;&gt;62 -->
+<g id="edge55" class="edge">
+<title>61&#45;&gt;62</title>
+<path fill="none" stroke="black" d="M1498.18,-215.7C1498.18,-207.98 1498.18,-198.71 1498.18,-190.11"/>
+<polygon fill="black" stroke="black" points="1501.68,-190.1 1498.18,-180.1 1494.68,-190.1 1501.68,-190.1"/>
+</g>
+<!-- 63 -->
+<g id="node55" class="node">
+<title>63</title>
+<polygon fill="none" stroke="black" points="1541.18,-108 1455.18,-108 1455.18,-72 1541.18,-72 1541.18,-108"/>
+<text text-anchor="middle" x="1498.18" y="-86.3" font-family="Times,serif" font-size="14.00">Tuple[...])</text>
+</g>
+<!-- 62&#45;&gt;63 -->
+<g id="edge56" class="edge">
+<title>62&#45;&gt;63</title>
+<path fill="none" stroke="black" d="M1498.18,-143.7C1498.18,-135.98 1498.18,-126.71 1498.18,-118.11"/>
+<polygon fill="black" stroke="black" points="1501.68,-118.1 1498.18,-108.1 1494.68,-118.1 1501.68,-118.1"/>
+</g>
+<!-- 64 -->
+<g id="node56" class="node">
+<title>64</title>
+<polygon fill="none" stroke="black" points="1538.18,-36 1458.18,-36 1458.18,0 1538.18,0 1538.18,-36"/>
+<text text-anchor="middle" x="1498.18" y="-14.3" font-family="Times,serif" font-size="14.00">Function</text>
+</g>
+<!-- 63&#45;&gt;64 -->
+<g id="edge57" class="edge">
+<title>63&#45;&gt;64</title>
+<path fill="none" stroke="black" d="M1498.18,-71.7C1498.18,-63.98 1498.18,-54.71 1498.18,-46.11"/>
+<polygon fill="black" stroke="black" points="1501.68,-46.1 1498.18,-36.1 1494.68,-46.1 1501.68,-46.1"/>
+</g>
+</g>
+</svg>
diff --git a/images/bert-pytorch/bert-tvm_65_2.svg b/images/bert-pytorch/bert-tvm_65_2.svg
new file mode 100644
index 0000000..4b26fbd
--- /dev/null
+++ b/images/bert-pytorch/bert-tvm_65_2.svg
@@ -0,0 +1,667 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
+ "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
+<!-- Generated by graphviz version 2.43.0 (0)
+ -->
+<!-- Title: %3 Pages: 1 -->
+<svg width="2122pt" height="1916pt"
+ viewBox="0.00 0.00 2122.14 1916.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
+<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 1912)">
+<title>%3</title>
+<polygon fill="white" stroke="transparent" points="-4,4 -4,-1912 2118.14,-1912 2118.14,4 -4,4"/>
+<!-- 0 -->
+<g id="node1" class="node">
+<title>0</title>
+<ellipse fill="none" stroke="black" cx="1461.64" cy="-1746" rx="170.87" ry="18"/>
+<text text-anchor="middle" x="1461.64" y="-1742.3" font-family="Times,serif" font-size="14.00">input: Tensor[(1, 14, 768), float32]</text>
+</g>
+<!-- 16 -->
+<g id="node9" class="node">
+<title>16</title>
+<polygon fill="none" stroke="black" points="1620.64,-1692 1268.64,-1692 1268.64,-1656 1620.64,-1656 1620.64,-1692"/>
+<text text-anchor="middle" x="1444.64" y="-1670.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 0&#45;&gt;16 -->
+<g id="edge1" class="edge">
+<title>0&#45;&gt;16</title>
+<path fill="none" stroke="black" d="M1457.43,-1727.7C1455.54,-1719.9 1453.26,-1710.51 1451.15,-1701.83"/>
+<polygon fill="black" stroke="black" points="1454.55,-1701 1448.79,-1692.1 1447.75,-1702.65 1454.55,-1701"/>
+</g>
+<!-- 1 -->
+<g id="node2" class="node">
+<title>1</title>
+<ellipse fill="none" stroke="black" cx="1092.64" cy="-1890" rx="200.36" ry="18"/>
+<text text-anchor="middle" x="1092.64" y="-1886.3" font-family="Times,serif" font-size="14.00">query.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 17 -->
+<g id="node10" class="node">
+<title>17</title>
+<polygon fill="none" stroke="black" points="1189.14,-1836 996.14,-1836 996.14,-1800 1189.14,-1800 1189.14,-1836"/>
+<text text-anchor="middle" x="1092.64" y="-1814.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 1&#45;&gt;17 -->
+<g id="edge2" class="edge">
+<title>1&#45;&gt;17</title>
+<path fill="none" stroke="black" d="M1092.64,-1871.7C1092.64,-1863.98 1092.64,-1854.71 1092.64,-1846.11"/>
+<polygon fill="black" stroke="black" points="1096.14,-1846.1 1092.64,-1836.1 1089.14,-1846.1 1096.14,-1846.1"/>
+</g>
+<!-- 2 -->
+<g id="node3" class="node">
+<title>2</title>
+<ellipse fill="none" stroke="black" cx="863.64" cy="-1530" rx="167.07" ry="18"/>
+<text text-anchor="middle" x="863.64" y="-1526.3" font-family="Times,serif" font-size="14.00">query.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 22 -->
+<g id="node15" class="node">
+<title>22</title>
+<polygon fill="none" stroke="black" points="1002.14,-1476 931.14,-1476 931.14,-1440 1002.14,-1440 1002.14,-1476"/>
+<text text-anchor="middle" x="966.64" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 2&#45;&gt;22 -->
+<g id="edge9" class="edge">
+<title>2&#45;&gt;22</title>
+<path fill="none" stroke="black" d="M888.57,-1512.05C901.88,-1503.01 918.46,-1491.74 932.9,-1481.93"/>
+<polygon fill="black" stroke="black" points="935.27,-1484.55 941.58,-1476.03 931.34,-1478.76 935.27,-1484.55"/>
+</g>
+<!-- 3 -->
+<g id="node4" class="node">
+<title>3</title>
+<ellipse fill="none" stroke="black" cx="564.64" cy="-1890" rx="189.57" ry="18"/>
+<text text-anchor="middle" x="564.64" y="-1886.3" font-family="Times,serif" font-size="14.00">key.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 26 -->
+<g id="node19" class="node">
+<title>26</title>
+<polygon fill="none" stroke="black" points="661.14,-1836 468.14,-1836 468.14,-1800 661.14,-1800 661.14,-1836"/>
+<text text-anchor="middle" x="564.64" y="-1814.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 3&#45;&gt;26 -->
+<g id="edge13" class="edge">
+<title>3&#45;&gt;26</title>
+<path fill="none" stroke="black" d="M564.64,-1871.7C564.64,-1863.98 564.64,-1854.71 564.64,-1846.11"/>
+<polygon fill="black" stroke="black" points="568.14,-1846.1 564.64,-1836.1 561.14,-1846.1 568.14,-1846.1"/>
+</g>
+<!-- 4 -->
+<g id="node5" class="node">
+<title>4</title>
+<ellipse fill="none" stroke="black" cx="156.64" cy="-1530" rx="156.77" ry="18"/>
+<text text-anchor="middle" x="156.64" y="-1526.3" font-family="Times,serif" font-size="14.00">key.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 31 -->
+<g id="node24" class="node">
+<title>31</title>
+<polygon fill="none" stroke="black" points="540.14,-1476 469.14,-1476 469.14,-1440 540.14,-1440 540.14,-1476"/>
+<text text-anchor="middle" x="504.64" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 4&#45;&gt;31 -->
+<g id="edge20" class="edge">
+<title>4&#45;&gt;31</title>
+<path fill="none" stroke="black" d="M230.86,-1514.07C299.97,-1500.17 400.8,-1479.89 458.85,-1468.21"/>
+<polygon fill="black" stroke="black" points="459.71,-1471.61 468.83,-1466.2 458.33,-1464.74 459.71,-1471.61"/>
+</g>
+<!-- 5 -->
+<g id="node6" class="node">
+<title>5</title>
+<ellipse fill="none" stroke="black" cx="1325.64" cy="-882" rx="217.96" ry="18"/>
+<text text-anchor="middle" x="1325.64" y="-878.3" font-family="Times,serif" font-size="14.00">attention_mask: Tensor[(1, 1, 1, 14), float32]</text>
+</g>
+<!-- 41 -->
+<g id="node33" class="node">
+<title>41</title>
+<polygon fill="none" stroke="black" points="1361.14,-828 1290.14,-828 1290.14,-792 1361.14,-792 1361.14,-828"/>
+<text text-anchor="middle" x="1325.64" y="-806.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 5&#45;&gt;41 -->
+<g id="edge31" class="edge">
+<title>5&#45;&gt;41</title>
+<path fill="none" stroke="black" d="M1325.64,-863.7C1325.64,-855.98 1325.64,-846.71 1325.64,-838.11"/>
+<polygon fill="black" stroke="black" points="1329.14,-838.1 1325.64,-828.1 1322.14,-838.1 1329.14,-838.1"/>
+</g>
+<!-- 6 -->
+<g id="node7" class="node">
+<title>6</title>
+<ellipse fill="none" stroke="black" cx="1885.64" cy="-1890" rx="200.36" ry="18"/>
+<text text-anchor="middle" x="1885.64" y="-1886.3" font-family="Times,serif" font-size="14.00">value.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 46 -->
+<g id="node38" class="node">
+<title>46</title>
+<polygon fill="none" stroke="black" points="1982.14,-1836 1789.14,-1836 1789.14,-1800 1982.14,-1800 1982.14,-1836"/>
+<text text-anchor="middle" x="1885.64" y="-1814.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 6&#45;&gt;46 -->
+<g id="edge36" class="edge">
+<title>6&#45;&gt;46</title>
+<path fill="none" stroke="black" d="M1885.64,-1871.7C1885.64,-1863.98 1885.64,-1854.71 1885.64,-1846.11"/>
+<polygon fill="black" stroke="black" points="1889.14,-1846.1 1885.64,-1836.1 1882.14,-1846.1 1889.14,-1846.1"/>
+</g>
+<!-- 7 -->
+<g id="node8" class="node">
+<title>7</title>
+<ellipse fill="none" stroke="black" cx="1581.64" cy="-1530" rx="167.07" ry="18"/>
+<text text-anchor="middle" x="1581.64" y="-1526.3" font-family="Times,serif" font-size="14.00">value.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 51 -->
+<g id="node43" class="node">
+<title>51</title>
+<polygon fill="none" stroke="black" points="1617.14,-1476 1546.14,-1476 1546.14,-1440 1617.14,-1440 1617.14,-1476"/>
+<text text-anchor="middle" x="1581.64" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 7&#45;&gt;51 -->
+<g id="edge43" class="edge">
+<title>7&#45;&gt;51</title>
+<path fill="none" stroke="black" d="M1581.64,-1511.7C1581.64,-1503.98 1581.64,-1494.71 1581.64,-1486.11"/>
+<polygon fill="black" stroke="black" points="1585.14,-1486.1 1581.64,-1476.1 1578.14,-1486.1 1585.14,-1486.1"/>
+</g>
+<!-- 20 -->
+<g id="node13" class="node">
+<title>20</title>
+<polygon fill="none" stroke="black" points="1307.14,-1620 1138.14,-1620 1138.14,-1584 1307.14,-1584 1307.14,-1620"/>
+<text text-anchor="middle" x="1222.64" y="-1598.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 16&#45;&gt;20 -->
+<g id="edge5" class="edge">
+<title>16&#45;&gt;20</title>
+<path fill="none" stroke="black" d="M1390.61,-1655.97C1359.19,-1646.06 1319.41,-1633.51 1286.45,-1623.12"/>
+<polygon fill="black" stroke="black" points="1287.49,-1619.78 1276.9,-1620.11 1285.39,-1626.46 1287.49,-1619.78"/>
+</g>
+<!-- 29 -->
+<g id="node22" class="node">
+<title>29</title>
+<polygon fill="none" stroke="black" points="649.14,-1620 480.14,-1620 480.14,-1584 649.14,-1584 649.14,-1620"/>
+<text text-anchor="middle" x="564.64" y="-1598.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 16&#45;&gt;29 -->
+<g id="edge16" class="edge">
+<title>16&#45;&gt;29</title>
+<path fill="none" stroke="black" d="M1268.41,-1658.98C1085.06,-1644.4 804.87,-1622.11 659.51,-1610.55"/>
+<polygon fill="black" stroke="black" points="659.53,-1607.04 649.28,-1609.73 658.97,-1614.02 659.53,-1607.04"/>
+</g>
+<!-- 49 -->
+<g id="node41" class="node">
+<title>49</title>
+<polygon fill="none" stroke="black" points="1970.14,-1620 1801.14,-1620 1801.14,-1584 1970.14,-1584 1970.14,-1620"/>
+<text text-anchor="middle" x="1885.64" y="-1598.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 16&#45;&gt;49 -->
+<g id="edge39" class="edge">
+<title>16&#45;&gt;49</title>
+<path fill="none" stroke="black" d="M1551.95,-1655.97C1624.98,-1644.37 1720.72,-1629.18 1790.91,-1618.04"/>
+<polygon fill="black" stroke="black" points="1791.75,-1621.45 1801.08,-1616.42 1790.65,-1614.53 1791.75,-1621.45"/>
+</g>
+<!-- 18 -->
+<g id="node11" class="node">
+<title>18</title>
+<polygon fill="none" stroke="black" points="1273.14,-1764 912.14,-1764 912.14,-1728 1273.14,-1728 1273.14,-1764"/>
+<text text-anchor="middle" x="1092.64" y="-1742.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 17&#45;&gt;18 -->
+<g id="edge3" class="edge">
+<title>17&#45;&gt;18</title>
+<path fill="none" stroke="black" d="M1092.64,-1799.7C1092.64,-1791.98 1092.64,-1782.71 1092.64,-1774.11"/>
+<polygon fill="black" stroke="black" points="1096.14,-1774.1 1092.64,-1764.1 1089.14,-1774.1 1096.14,-1774.1"/>
+</g>
+<!-- 19 -->
+<g id="node12" class="node">
+<title>19</title>
+<polygon fill="none" stroke="black" points="1216.14,-1692 1005.14,-1692 1005.14,-1656 1216.14,-1656 1216.14,-1692"/>
+<text text-anchor="middle" x="1110.64" y="-1670.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 18&#45;&gt;19 -->
+<g id="edge4" class="edge">
+<title>18&#45;&gt;19</title>
+<path fill="none" stroke="black" d="M1097.09,-1727.7C1099.09,-1719.9 1101.51,-1710.51 1103.74,-1701.83"/>
+<polygon fill="black" stroke="black" points="1107.14,-1702.66 1106.24,-1692.1 1100.36,-1700.92 1107.14,-1702.66"/>
+</g>
+<!-- 19&#45;&gt;20 -->
+<g id="edge6" class="edge">
+<title>19&#45;&gt;20</title>
+<path fill="none" stroke="black" d="M1138.03,-1655.88C1152.54,-1646.81 1170.55,-1635.55 1186.22,-1625.76"/>
+<polygon fill="black" stroke="black" points="1188.51,-1628.46 1195.13,-1620.19 1184.8,-1622.52 1188.51,-1628.46"/>
+</g>
+<!-- 21 -->
+<g id="node14" class="node">
+<title>21</title>
+<polygon fill="none" stroke="black" points="1396.14,-1548 1049.14,-1548 1049.14,-1512 1396.14,-1512 1396.14,-1548"/>
+<text text-anchor="middle" x="1222.64" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 20&#45;&gt;21 -->
+<g id="edge7" class="edge">
+<title>20&#45;&gt;21</title>
+<path fill="none" stroke="black" d="M1222.64,-1583.7C1222.64,-1575.98 1222.64,-1566.71 1222.64,-1558.11"/>
+<polygon fill="black" stroke="black" points="1226.14,-1558.1 1222.64,-1548.1 1219.14,-1558.1 1226.14,-1558.1"/>
+</g>
+<!-- 21&#45;&gt;22 -->
+<g id="edge8" class="edge">
+<title>21&#45;&gt;22</title>
+<path fill="none" stroke="black" d="M1160.34,-1511.97C1114.53,-1499.44 1053.33,-1482.71 1012.08,-1471.43"/>
+<polygon fill="black" stroke="black" points="1012.87,-1468.01 1002.3,-1468.75 1011.03,-1474.77 1012.87,-1468.01"/>
+</g>
+<!-- 23 -->
+<g id="node16" class="node">
+<title>23</title>
+<polygon fill="none" stroke="black" points="1149.14,-1404 784.14,-1404 784.14,-1368 1149.14,-1368 1149.14,-1404"/>
+<text text-anchor="middle" x="966.64" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 22&#45;&gt;23 -->
+<g id="edge10" class="edge">
+<title>22&#45;&gt;23</title>
+<path fill="none" stroke="black" d="M966.64,-1439.7C966.64,-1431.98 966.64,-1422.71 966.64,-1414.11"/>
+<polygon fill="black" stroke="black" points="970.14,-1414.1 966.64,-1404.1 963.14,-1414.1 970.14,-1414.1"/>
+</g>
+<!-- 24 -->
+<g id="node17" class="node">
+<title>24</title>
+<polygon fill="none" stroke="black" points="1081.14,-1332 852.14,-1332 852.14,-1296 1081.14,-1296 1081.14,-1332"/>
+<text text-anchor="middle" x="966.64" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 23&#45;&gt;24 -->
+<g id="edge11" class="edge">
+<title>23&#45;&gt;24</title>
+<path fill="none" stroke="black" d="M966.64,-1367.7C966.64,-1359.98 966.64,-1350.71 966.64,-1342.11"/>
+<polygon fill="black" stroke="black" points="970.14,-1342.1 966.64,-1332.1 963.14,-1342.1 970.14,-1342.1"/>
+</g>
+<!-- 25 -->
+<g id="node18" class="node">
+<title>25</title>
+<polygon fill="none" stroke="black" points="1138.14,-1116 795.14,-1116 795.14,-1080 1138.14,-1080 1138.14,-1116"/>
+<text text-anchor="middle" x="966.64" y="-1094.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 64], reverse=0)</text>
+</g>
+<!-- 24&#45;&gt;25 -->
+<g id="edge12" class="edge">
+<title>24&#45;&gt;25</title>
+<path fill="none" stroke="black" d="M966.64,-1295.85C966.64,-1258.83 966.64,-1171.18 966.64,-1126.39"/>
+<polygon fill="black" stroke="black" points="970.14,-1126.23 966.64,-1116.23 963.14,-1126.23 970.14,-1126.23"/>
+</g>
+<!-- 37 -->
+<g id="node30" class="node">
+<title>37</title>
+<polygon fill="none" stroke="black" points="1051.14,-1044 882.14,-1044 882.14,-1008 1051.14,-1008 1051.14,-1044"/>
+<text text-anchor="middle" x="966.64" y="-1022.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 25&#45;&gt;37 -->
+<g id="edge26" class="edge">
+<title>25&#45;&gt;37</title>
+<path fill="none" stroke="black" d="M966.64,-1079.7C966.64,-1071.98 966.64,-1062.71 966.64,-1054.11"/>
+<polygon fill="black" stroke="black" points="970.14,-1054.1 966.64,-1044.1 963.14,-1054.1 970.14,-1054.1"/>
+</g>
+<!-- 27 -->
+<g id="node20" class="node">
+<title>27</title>
+<polygon fill="none" stroke="black" points="745.14,-1764 384.14,-1764 384.14,-1728 745.14,-1728 745.14,-1764"/>
+<text text-anchor="middle" x="564.64" y="-1742.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 26&#45;&gt;27 -->
+<g id="edge14" class="edge">
+<title>26&#45;&gt;27</title>
+<path fill="none" stroke="black" d="M564.64,-1799.7C564.64,-1791.98 564.64,-1782.71 564.64,-1774.11"/>
+<polygon fill="black" stroke="black" points="568.14,-1774.1 564.64,-1764.1 561.14,-1774.1 568.14,-1774.1"/>
+</g>
+<!-- 28 -->
+<g id="node21" class="node">
+<title>28</title>
+<polygon fill="none" stroke="black" points="670.14,-1692 459.14,-1692 459.14,-1656 670.14,-1656 670.14,-1692"/>
+<text text-anchor="middle" x="564.64" y="-1670.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 27&#45;&gt;28 -->
+<g id="edge15" class="edge">
+<title>27&#45;&gt;28</title>
+<path fill="none" stroke="black" d="M564.64,-1727.7C564.64,-1719.98 564.64,-1710.71 564.64,-1702.11"/>
+<polygon fill="black" stroke="black" points="568.14,-1702.1 564.64,-1692.1 561.14,-1702.1 568.14,-1702.1"/>
+</g>
+<!-- 28&#45;&gt;29 -->
+<g id="edge17" class="edge">
+<title>28&#45;&gt;29</title>
+<path fill="none" stroke="black" d="M564.64,-1655.7C564.64,-1647.98 564.64,-1638.71 564.64,-1630.11"/>
+<polygon fill="black" stroke="black" points="568.14,-1630.1 564.64,-1620.1 561.14,-1630.1 568.14,-1630.1"/>
+</g>
+<!-- 30 -->
+<g id="node23" class="node">
+<title>30</title>
+<polygon fill="none" stroke="black" points="678.14,-1548 331.14,-1548 331.14,-1512 678.14,-1512 678.14,-1548"/>
+<text text-anchor="middle" x="504.64" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 29&#45;&gt;30 -->
+<g id="edge18" class="edge">
+<title>29&#45;&gt;30</title>
+<path fill="none" stroke="black" d="M549.8,-1583.7C542.61,-1575.3 533.84,-1565.07 525.95,-1555.86"/>
+<polygon fill="black" stroke="black" points="528.46,-1553.42 519.3,-1548.1 523.15,-1557.97 528.46,-1553.42"/>
+</g>
+<!-- 30&#45;&gt;31 -->
+<g id="edge19" class="edge">
+<title>30&#45;&gt;31</title>
+<path fill="none" stroke="black" d="M504.64,-1511.7C504.64,-1503.98 504.64,-1494.71 504.64,-1486.11"/>
+<polygon fill="black" stroke="black" points="508.14,-1486.1 504.64,-1476.1 501.14,-1486.1 508.14,-1486.1"/>
+</g>
+<!-- 32 -->
+<g id="node25" class="node">
+<title>32</title>
+<polygon fill="none" stroke="black" points="766.14,-1404 401.14,-1404 401.14,-1368 766.14,-1368 766.14,-1404"/>
+<text text-anchor="middle" x="583.64" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 31&#45;&gt;32 -->
+<g id="edge21" class="edge">
+<title>31&#45;&gt;32</title>
+<path fill="none" stroke="black" d="M524.16,-1439.7C534.02,-1430.97 546.12,-1420.24 556.83,-1410.75"/>
+<polygon fill="black" stroke="black" points="559.17,-1413.36 564.33,-1404.1 554.53,-1408.12 559.17,-1413.36"/>
+</g>
+<!-- 33 -->
+<g id="node26" class="node">
+<title>33</title>
+<polygon fill="none" stroke="black" points="704.14,-1332 475.14,-1332 475.14,-1296 704.14,-1296 704.14,-1332"/>
+<text text-anchor="middle" x="589.64" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 32&#45;&gt;33 -->
+<g id="edge22" class="edge">
+<title>32&#45;&gt;33</title>
+<path fill="none" stroke="black" d="M585.12,-1367.7C585.78,-1359.98 586.58,-1350.71 587.31,-1342.11"/>
+<polygon fill="black" stroke="black" points="590.8,-1342.37 588.17,-1332.1 583.83,-1341.77 590.8,-1342.37"/>
+</g>
+<!-- 34 -->
+<g id="node27" class="node">
+<title>34</title>
+<polygon fill="none" stroke="black" points="748.14,-1260 519.14,-1260 519.14,-1224 748.14,-1224 748.14,-1260"/>
+<text text-anchor="middle" x="633.64" y="-1238.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 1, 3, 2])</text>
+</g>
+<!-- 33&#45;&gt;34 -->
+<g id="edge23" class="edge">
+<title>33&#45;&gt;34</title>
+<path fill="none" stroke="black" d="M600.51,-1295.7C605.63,-1287.56 611.83,-1277.69 617.48,-1268.7"/>
+<polygon fill="black" stroke="black" points="620.53,-1270.43 622.88,-1260.1 614.6,-1266.71 620.53,-1270.43"/>
+</g>
+<!-- 35 -->
+<g id="node28" class="node">
+<title>35</title>
+<polygon fill="none" stroke="black" points="827.14,-1188 484.14,-1188 484.14,-1152 827.14,-1152 827.14,-1188"/>
+<text text-anchor="middle" x="655.64" y="-1166.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 64, 14], reverse=0)</text>
+</g>
+<!-- 34&#45;&gt;35 -->
+<g id="edge24" class="edge">
+<title>34&#45;&gt;35</title>
+<path fill="none" stroke="black" d="M639.07,-1223.7C641.53,-1215.9 644.48,-1206.51 647.2,-1197.83"/>
+<polygon fill="black" stroke="black" points="650.6,-1198.69 650.26,-1188.1 643.92,-1196.59 650.6,-1198.69"/>
+</g>
+<!-- 36 -->
+<g id="node29" class="node">
+<title>36</title>
+<polygon fill="none" stroke="black" points="772.14,-1116 561.14,-1116 561.14,-1080 772.14,-1080 772.14,-1116"/>
+<text text-anchor="middle" x="666.64" y="-1094.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 35&#45;&gt;36 -->
+<g id="edge25" class="edge">
+<title>35&#45;&gt;36</title>
+<path fill="none" stroke="black" d="M658.36,-1151.7C659.57,-1143.98 661.02,-1134.71 662.38,-1126.11"/>
+<polygon fill="black" stroke="black" points="665.85,-1126.53 663.95,-1116.1 658.94,-1125.44 665.85,-1126.53"/>
+</g>
+<!-- 36&#45;&gt;37 -->
+<g id="edge27" class="edge">
+<title>36&#45;&gt;37</title>
+<path fill="none" stroke="black" d="M739.64,-1079.97C783.19,-1069.8 838.61,-1056.87 883.79,-1046.33"/>
+<polygon fill="black" stroke="black" points="884.62,-1049.73 893.56,-1044.05 883.03,-1042.91 884.62,-1049.73"/>
+</g>
+<!-- 38 -->
+<g id="node31" class="node">
+<title>38</title>
+<polygon fill="none" stroke="black" points="1189.14,-972 824.14,-972 824.14,-936 1189.14,-936 1189.14,-972"/>
+<text text-anchor="middle" x="1006.64" y="-950.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 12, 14, 14], reverse=0)</text>
+</g>
+<!-- 37&#45;&gt;38 -->
+<g id="edge28" class="edge">
+<title>37&#45;&gt;38</title>
+<path fill="none" stroke="black" d="M976.52,-1007.7C981.13,-999.64 986.7,-989.89 991.79,-980.98"/>
+<polygon fill="black" stroke="black" points="994.94,-982.52 996.86,-972.1 988.86,-979.05 994.94,-982.52"/>
+</g>
+<!-- 40 -->
+<g id="node32" class="node">
+<title>40</title>
+<polygon fill="none" stroke="black" points="1080.14,-900 973.14,-900 973.14,-864 1080.14,-864 1080.14,-900"/>
+<text text-anchor="middle" x="1026.64" y="-878.3" font-family="Times,serif" font-size="14.00">divide(·, 8.0)</text>
+</g>
+<!-- 38&#45;&gt;40 -->
+<g id="edge29" class="edge">
+<title>38&#45;&gt;40</title>
+<path fill="none" stroke="black" d="M1011.58,-935.7C1013.81,-927.9 1016.49,-918.51 1018.97,-909.83"/>
+<polygon fill="black" stroke="black" points="1022.37,-910.68 1021.75,-900.1 1015.64,-908.76 1022.37,-910.68"/>
+</g>
+<!-- 40&#45;&gt;41 -->
+<g id="edge30" class="edge">
+<title>40&#45;&gt;41</title>
+<path fill="none" stroke="black" d="M1080.5,-868.27C1086.61,-866.83 1092.76,-865.38 1098.64,-864 1161.45,-849.25 1234.08,-832.31 1279.92,-821.64"/>
+<polygon fill="black" stroke="black" points="1280.91,-825 1289.85,-819.33 1279.32,-818.18 1280.91,-825"/>
+</g>
+<!-- 42 -->
+<g id="node34" class="node">
+<title>42</title>
+<polygon fill="none" stroke="black" points="1420.64,-756 1246.64,-756 1246.64,-720 1420.64,-720 1420.64,-756"/>
+<text text-anchor="middle" x="1333.64" y="-734.3" font-family="Times,serif" font-size="14.00">nn.softmax(·| axis=&#45;1)</text>
+</g>
+<!-- 41&#45;&gt;42 -->
+<g id="edge32" class="edge">
+<title>41&#45;&gt;42</title>
+<path fill="none" stroke="black" d="M1327.61,-791.7C1328.5,-783.98 1329.55,-774.71 1330.54,-766.11"/>
+<polygon fill="black" stroke="black" points="1334.02,-766.44 1331.68,-756.1 1327.07,-765.64 1334.02,-766.44"/>
+</g>
+<!-- 43 -->
+<g id="node35" class="node">
+<title>43</title>
+<polygon fill="none" stroke="black" points="1448.14,-684 1265.14,-684 1265.14,-648 1448.14,-648 1448.14,-684"/>
+<text text-anchor="middle" x="1356.64" y="-662.3" font-family="Times,serif" font-size="14.00">nn.dropout(·| rate=0.1)</text>
+</g>
+<!-- 42&#45;&gt;43 -->
+<g id="edge33" class="edge">
+<title>42&#45;&gt;43</title>
+<path fill="none" stroke="black" d="M1339.32,-719.7C1341.88,-711.9 1344.97,-702.51 1347.82,-693.83"/>
+<polygon fill="black" stroke="black" points="1351.22,-694.7 1351.02,-684.1 1344.57,-692.51 1351.22,-694.7"/>
+</g>
+<!-- 44 -->
+<g id="node36" class="node">
+<title>44</title>
+<polygon fill="none" stroke="black" points="1452.64,-612 1284.64,-612 1284.64,-576 1452.64,-576 1452.64,-612"/>
+<text text-anchor="middle" x="1368.64" y="-590.3" font-family="Times,serif" font-size="14.00">TupleGetItem(idx=0)</text>
+</g>
+<!-- 43&#45;&gt;44 -->
+<g id="edge34" class="edge">
+<title>43&#45;&gt;44</title>
+<path fill="none" stroke="black" d="M1359.6,-647.7C1360.92,-639.98 1362.51,-630.71 1363.99,-622.11"/>
+<polygon fill="black" stroke="black" points="1367.46,-622.55 1365.7,-612.1 1360.56,-621.37 1367.46,-622.55"/>
+</g>
+<!-- 45 -->
+<g id="node37" class="node">
+<title>45</title>
+<polygon fill="none" stroke="black" points="1546.14,-540 1203.14,-540 1203.14,-504 1546.14,-504 1546.14,-540"/>
+<text text-anchor="middle" x="1374.64" y="-518.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 14], reverse=0)</text>
+</g>
+<!-- 44&#45;&gt;45 -->
+<g id="edge35" class="edge">
+<title>44&#45;&gt;45</title>
+<path fill="none" stroke="black" d="M1370.12,-575.7C1370.78,-567.98 1371.58,-558.71 1372.31,-550.11"/>
+<polygon fill="black" stroke="black" points="1375.8,-550.37 1373.17,-540.1 1368.83,-549.77 1375.8,-550.37"/>
+</g>
+<!-- 56 -->
+<g id="node48" class="node">
+<title>56</title>
+<polygon fill="none" stroke="black" points="1558.14,-468 1389.14,-468 1389.14,-432 1558.14,-432 1558.14,-468"/>
+<text text-anchor="middle" x="1473.64" y="-446.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 45&#45;&gt;56 -->
+<g id="edge48" class="edge">
+<title>45&#45;&gt;56</title>
+<path fill="none" stroke="black" d="M1398.85,-503.88C1411.56,-494.89 1427.3,-483.76 1441.06,-474.03"/>
+<polygon fill="black" stroke="black" points="1443.18,-476.82 1449.32,-468.19 1439.14,-471.11 1443.18,-476.82"/>
+</g>
+<!-- 47 -->
+<g id="node39" class="node">
+<title>47</title>
+<polygon fill="none" stroke="black" points="2066.14,-1764 1705.14,-1764 1705.14,-1728 2066.14,-1728 2066.14,-1764"/>
+<text text-anchor="middle" x="1885.64" y="-1742.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 46&#45;&gt;47 -->
+<g id="edge37" class="edge">
+<title>46&#45;&gt;47</title>
+<path fill="none" stroke="black" d="M1885.64,-1799.7C1885.64,-1791.98 1885.64,-1782.71 1885.64,-1774.11"/>
+<polygon fill="black" stroke="black" points="1889.14,-1774.1 1885.64,-1764.1 1882.14,-1774.1 1889.14,-1774.1"/>
+</g>
+<!-- 48 -->
+<g id="node40" class="node">
+<title>48</title>
+<polygon fill="none" stroke="black" points="1991.14,-1692 1780.14,-1692 1780.14,-1656 1991.14,-1656 1991.14,-1692"/>
+<text text-anchor="middle" x="1885.64" y="-1670.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 47&#45;&gt;48 -->
+<g id="edge38" class="edge">
+<title>47&#45;&gt;48</title>
+<path fill="none" stroke="black" d="M1885.64,-1727.7C1885.64,-1719.98 1885.64,-1710.71 1885.64,-1702.11"/>
+<polygon fill="black" stroke="black" points="1889.14,-1702.1 1885.64,-1692.1 1882.14,-1702.1 1889.14,-1702.1"/>
+</g>
+<!-- 48&#45;&gt;49 -->
+<g id="edge40" class="edge">
+<title>48&#45;&gt;49</title>
+<path fill="none" stroke="black" d="M1885.64,-1655.7C1885.64,-1647.98 1885.64,-1638.71 1885.64,-1630.11"/>
+<polygon fill="black" stroke="black" points="1889.14,-1630.1 1885.64,-1620.1 1882.14,-1630.1 1889.14,-1630.1"/>
+</g>
+<!-- 50 -->
+<g id="node42" class="node">
+<title>50</title>
+<polygon fill="none" stroke="black" points="2114.14,-1548 1767.14,-1548 1767.14,-1512 2114.14,-1512 2114.14,-1548"/>
+<text text-anchor="middle" x="1940.64" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 49&#45;&gt;50 -->
+<g id="edge41" class="edge">
+<title>49&#45;&gt;50</title>
+<path fill="none" stroke="black" d="M1899.23,-1583.7C1905.76,-1575.39 1913.7,-1565.28 1920.88,-1556.14"/>
+<polygon fill="black" stroke="black" points="1923.77,-1558.13 1927.2,-1548.1 1918.27,-1553.81 1923.77,-1558.13"/>
+</g>
+<!-- 50&#45;&gt;51 -->
+<g id="edge42" class="edge">
+<title>50&#45;&gt;51</title>
+<path fill="none" stroke="black" d="M1853.27,-1511.97C1781.96,-1498.06 1684.02,-1478.96 1627.31,-1467.91"/>
+<polygon fill="black" stroke="black" points="1627.7,-1464.42 1617.22,-1465.94 1626.36,-1471.29 1627.7,-1464.42"/>
+</g>
+<!-- 52 -->
+<g id="node44" class="node">
+<title>52</title>
+<polygon fill="none" stroke="black" points="1764.14,-1404 1399.14,-1404 1399.14,-1368 1764.14,-1368 1764.14,-1404"/>
+<text text-anchor="middle" x="1581.64" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 51&#45;&gt;52 -->
+<g id="edge44" class="edge">
+<title>51&#45;&gt;52</title>
+<path fill="none" stroke="black" d="M1581.64,-1439.7C1581.64,-1431.98 1581.64,-1422.71 1581.64,-1414.11"/>
+<polygon fill="black" stroke="black" points="1585.14,-1414.1 1581.64,-1404.1 1578.14,-1414.1 1585.14,-1414.1"/>
+</g>
+<!-- 53 -->
+<g id="node45" class="node">
+<title>53</title>
+<polygon fill="none" stroke="black" points="1695.14,-1332 1466.14,-1332 1466.14,-1296 1695.14,-1296 1695.14,-1332"/>
+<text text-anchor="middle" x="1580.64" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 52&#45;&gt;53 -->
+<g id="edge45" class="edge">
+<title>52&#45;&gt;53</title>
+<path fill="none" stroke="black" d="M1581.39,-1367.7C1581.28,-1359.98 1581.15,-1350.71 1581.02,-1342.11"/>
+<polygon fill="black" stroke="black" points="1584.52,-1342.05 1580.88,-1332.1 1577.52,-1342.15 1584.52,-1342.05"/>
+</g>
+<!-- 54 -->
+<g id="node46" class="node">
+<title>54</title>
+<polygon fill="none" stroke="black" points="1751.14,-1260 1408.14,-1260 1408.14,-1224 1751.14,-1224 1751.14,-1260"/>
+<text text-anchor="middle" x="1579.64" y="-1238.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 64], reverse=0)</text>
+</g>
+<!-- 53&#45;&gt;54 -->
+<g id="edge46" class="edge">
+<title>53&#45;&gt;54</title>
+<path fill="none" stroke="black" d="M1580.39,-1295.7C1580.28,-1287.98 1580.15,-1278.71 1580.02,-1270.11"/>
+<polygon fill="black" stroke="black" points="1583.52,-1270.05 1579.88,-1260.1 1576.52,-1270.15 1583.52,-1270.05"/>
+</g>
+<!-- 55 -->
+<g id="node47" class="node">
+<title>55</title>
+<polygon fill="none" stroke="black" points="1682.14,-1044 1471.14,-1044 1471.14,-1008 1682.14,-1008 1682.14,-1044"/>
+<text text-anchor="middle" x="1576.64" y="-1022.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 54&#45;&gt;55 -->
+<g id="edge47" class="edge">
+<title>54&#45;&gt;55</title>
+<path fill="none" stroke="black" d="M1579.4,-1223.85C1578.88,-1186.83 1577.65,-1099.18 1577.02,-1054.39"/>
+<polygon fill="black" stroke="black" points="1580.52,-1054.18 1576.88,-1044.23 1573.52,-1054.28 1580.52,-1054.18"/>
+</g>
+<!-- 55&#45;&gt;56 -->
+<g id="edge49" class="edge">
+<title>55&#45;&gt;56</title>
+<path fill="none" stroke="black" d="M1576.64,-1007.95C1576.64,-981.29 1576.64,-928.11 1576.64,-883 1576.64,-883 1576.64,-883 1576.64,-593 1576.64,-552.36 1579.06,-537.21 1555.64,-504 1546.9,-491.62 1534.27,-481.42 1521.53,-473.4"/>
+<polygon fill="black" stroke="black" points="1523.13,-470.27 1512.74,-468.19 1519.56,-476.3 1523.13,-470.27"/>
+</g>
+<!-- 57 -->
+<g id="node49" class="node">
+<title>57</title>
+<polygon fill="none" stroke="black" points="1656.14,-396 1291.14,-396 1291.14,-360 1656.14,-360 1656.14,-396"/>
+<text text-anchor="middle" x="1473.64" y="-374.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 12, 14, 64], reverse=0)</text>
+</g>
+<!-- 56&#45;&gt;57 -->
+<g id="edge50" class="edge">
+<title>56&#45;&gt;57</title>
+<path fill="none" stroke="black" d="M1473.64,-431.7C1473.64,-423.98 1473.64,-414.71 1473.64,-406.11"/>
+<polygon fill="black" stroke="black" points="1477.14,-406.1 1473.64,-396.1 1470.14,-406.1 1477.14,-406.1"/>
+</g>
+<!-- 58 -->
+<g id="node50" class="node">
+<title>58</title>
+<polygon fill="none" stroke="black" points="1588.14,-324 1359.14,-324 1359.14,-288 1588.14,-288 1588.14,-324"/>
+<text text-anchor="middle" x="1473.64" y="-302.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 57&#45;&gt;58 -->
+<g id="edge51" class="edge">
+<title>57&#45;&gt;58</title>
+<path fill="none" stroke="black" d="M1473.64,-359.7C1473.64,-351.98 1473.64,-342.71 1473.64,-334.11"/>
+<polygon fill="black" stroke="black" points="1477.14,-334.1 1473.64,-324.1 1470.14,-334.1 1477.14,-334.1"/>
+</g>
+<!-- 59 -->
+<g id="node51" class="node">
+<title>59</title>
+<polygon fill="none" stroke="black" points="1506.14,-252 1441.14,-252 1441.14,-216 1506.14,-216 1506.14,-252"/>
+<text text-anchor="middle" x="1473.64" y="-230.3" font-family="Times,serif" font-size="14.00">copy(·)</text>
+</g>
+<!-- 58&#45;&gt;59 -->
+<g id="edge52" class="edge">
+<title>58&#45;&gt;59</title>
+<path fill="none" stroke="black" d="M1473.64,-287.7C1473.64,-279.98 1473.64,-270.71 1473.64,-262.11"/>
+<polygon fill="black" stroke="black" points="1477.14,-262.1 1473.64,-252.1 1470.14,-262.1 1477.14,-262.1"/>
+</g>
+<!-- 60 -->
+<g id="node52" class="node">
+<title>60</title>
+<polygon fill="none" stroke="black" points="1647.14,-180 1300.14,-180 1300.14,-144 1647.14,-144 1647.14,-180"/>
+<text text-anchor="middle" x="1473.64" y="-158.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 59&#45;&gt;60 -->
+<g id="edge53" class="edge">
+<title>59&#45;&gt;60</title>
+<path fill="none" stroke="black" d="M1473.64,-215.7C1473.64,-207.98 1473.64,-198.71 1473.64,-190.11"/>
+<polygon fill="black" stroke="black" points="1477.14,-190.1 1473.64,-180.1 1470.14,-190.1 1477.14,-190.1"/>
+</g>
+<!-- 61 -->
+<g id="node53" class="node">
+<title>61</title>
+<polygon fill="none" stroke="black" points="1516.64,-108 1430.64,-108 1430.64,-72 1516.64,-72 1516.64,-108"/>
+<text text-anchor="middle" x="1473.64" y="-86.3" font-family="Times,serif" font-size="14.00">Tuple[...])</text>
+</g>
+<!-- 60&#45;&gt;61 -->
+<g id="edge54" class="edge">
+<title>60&#45;&gt;61</title>
+<path fill="none" stroke="black" d="M1473.64,-143.7C1473.64,-135.98 1473.64,-126.71 1473.64,-118.11"/>
+<polygon fill="black" stroke="black" points="1477.14,-118.1 1473.64,-108.1 1470.14,-118.1 1477.14,-118.1"/>
+</g>
+<!-- 62 -->
+<g id="node54" class="node">
+<title>62</title>
+<polygon fill="none" stroke="black" points="1513.64,-36 1433.64,-36 1433.64,0 1513.64,0 1513.64,-36"/>
+<text text-anchor="middle" x="1473.64" y="-14.3" font-family="Times,serif" font-size="14.00">Function</text>
+</g>
+<!-- 61&#45;&gt;62 -->
+<g id="edge55" class="edge">
+<title>61&#45;&gt;62</title>
+<path fill="none" stroke="black" d="M1473.64,-71.7C1473.64,-63.98 1473.64,-54.71 1473.64,-46.11"/>
+<polygon fill="black" stroke="black" points="1477.14,-46.1 1473.64,-36.1 1470.14,-46.1 1477.14,-46.1"/>
+</g>
+</g>
+</svg>
diff --git a/images/bert-pytorch/bert-tvm_68_0.svg b/images/bert-pytorch/bert-tvm_68_0.svg
new file mode 100644
index 0000000..4b26fbd
--- /dev/null
+++ b/images/bert-pytorch/bert-tvm_68_0.svg
@@ -0,0 +1,667 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
+ "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
+<!-- Generated by graphviz version 2.43.0 (0)
+ -->
+<!-- Title: %3 Pages: 1 -->
+<svg width="2122pt" height="1916pt"
+ viewBox="0.00 0.00 2122.14 1916.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
+<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 1912)">
+<title>%3</title>
+<polygon fill="white" stroke="transparent" points="-4,4 -4,-1912 2118.14,-1912 2118.14,4 -4,4"/>
+<!-- 0 -->
+<g id="node1" class="node">
+<title>0</title>
+<ellipse fill="none" stroke="black" cx="1461.64" cy="-1746" rx="170.87" ry="18"/>
+<text text-anchor="middle" x="1461.64" y="-1742.3" font-family="Times,serif" font-size="14.00">input: Tensor[(1, 14, 768), float32]</text>
+</g>
+<!-- 16 -->
+<g id="node9" class="node">
+<title>16</title>
+<polygon fill="none" stroke="black" points="1620.64,-1692 1268.64,-1692 1268.64,-1656 1620.64,-1656 1620.64,-1692"/>
+<text text-anchor="middle" x="1444.64" y="-1670.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 0&#45;&gt;16 -->
+<g id="edge1" class="edge">
+<title>0&#45;&gt;16</title>
+<path fill="none" stroke="black" d="M1457.43,-1727.7C1455.54,-1719.9 1453.26,-1710.51 1451.15,-1701.83"/>
+<polygon fill="black" stroke="black" points="1454.55,-1701 1448.79,-1692.1 1447.75,-1702.65 1454.55,-1701"/>
+</g>
+<!-- 1 -->
+<g id="node2" class="node">
+<title>1</title>
+<ellipse fill="none" stroke="black" cx="1092.64" cy="-1890" rx="200.36" ry="18"/>
+<text text-anchor="middle" x="1092.64" y="-1886.3" font-family="Times,serif" font-size="14.00">query.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 17 -->
+<g id="node10" class="node">
+<title>17</title>
+<polygon fill="none" stroke="black" points="1189.14,-1836 996.14,-1836 996.14,-1800 1189.14,-1800 1189.14,-1836"/>
+<text text-anchor="middle" x="1092.64" y="-1814.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 1&#45;&gt;17 -->
+<g id="edge2" class="edge">
+<title>1&#45;&gt;17</title>
+<path fill="none" stroke="black" d="M1092.64,-1871.7C1092.64,-1863.98 1092.64,-1854.71 1092.64,-1846.11"/>
+<polygon fill="black" stroke="black" points="1096.14,-1846.1 1092.64,-1836.1 1089.14,-1846.1 1096.14,-1846.1"/>
+</g>
+<!-- 2 -->
+<g id="node3" class="node">
+<title>2</title>
+<ellipse fill="none" stroke="black" cx="863.64" cy="-1530" rx="167.07" ry="18"/>
+<text text-anchor="middle" x="863.64" y="-1526.3" font-family="Times,serif" font-size="14.00">query.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 22 -->
+<g id="node15" class="node">
+<title>22</title>
+<polygon fill="none" stroke="black" points="1002.14,-1476 931.14,-1476 931.14,-1440 1002.14,-1440 1002.14,-1476"/>
+<text text-anchor="middle" x="966.64" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 2&#45;&gt;22 -->
+<g id="edge9" class="edge">
+<title>2&#45;&gt;22</title>
+<path fill="none" stroke="black" d="M888.57,-1512.05C901.88,-1503.01 918.46,-1491.74 932.9,-1481.93"/>
+<polygon fill="black" stroke="black" points="935.27,-1484.55 941.58,-1476.03 931.34,-1478.76 935.27,-1484.55"/>
+</g>
+<!-- 3 -->
+<g id="node4" class="node">
+<title>3</title>
+<ellipse fill="none" stroke="black" cx="564.64" cy="-1890" rx="189.57" ry="18"/>
+<text text-anchor="middle" x="564.64" y="-1886.3" font-family="Times,serif" font-size="14.00">key.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 26 -->
+<g id="node19" class="node">
+<title>26</title>
+<polygon fill="none" stroke="black" points="661.14,-1836 468.14,-1836 468.14,-1800 661.14,-1800 661.14,-1836"/>
+<text text-anchor="middle" x="564.64" y="-1814.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 3&#45;&gt;26 -->
+<g id="edge13" class="edge">
+<title>3&#45;&gt;26</title>
+<path fill="none" stroke="black" d="M564.64,-1871.7C564.64,-1863.98 564.64,-1854.71 564.64,-1846.11"/>
+<polygon fill="black" stroke="black" points="568.14,-1846.1 564.64,-1836.1 561.14,-1846.1 568.14,-1846.1"/>
+</g>
+<!-- 4 -->
+<g id="node5" class="node">
+<title>4</title>
+<ellipse fill="none" stroke="black" cx="156.64" cy="-1530" rx="156.77" ry="18"/>
+<text text-anchor="middle" x="156.64" y="-1526.3" font-family="Times,serif" font-size="14.00">key.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 31 -->
+<g id="node24" class="node">
+<title>31</title>
+<polygon fill="none" stroke="black" points="540.14,-1476 469.14,-1476 469.14,-1440 540.14,-1440 540.14,-1476"/>
+<text text-anchor="middle" x="504.64" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 4&#45;&gt;31 -->
+<g id="edge20" class="edge">
+<title>4&#45;&gt;31</title>
+<path fill="none" stroke="black" d="M230.86,-1514.07C299.97,-1500.17 400.8,-1479.89 458.85,-1468.21"/>
+<polygon fill="black" stroke="black" points="459.71,-1471.61 468.83,-1466.2 458.33,-1464.74 459.71,-1471.61"/>
+</g>
+<!-- 5 -->
+<g id="node6" class="node">
+<title>5</title>
+<ellipse fill="none" stroke="black" cx="1325.64" cy="-882" rx="217.96" ry="18"/>
+<text text-anchor="middle" x="1325.64" y="-878.3" font-family="Times,serif" font-size="14.00">attention_mask: Tensor[(1, 1, 1, 14), float32]</text>
+</g>
+<!-- 41 -->
+<g id="node33" class="node">
+<title>41</title>
+<polygon fill="none" stroke="black" points="1361.14,-828 1290.14,-828 1290.14,-792 1361.14,-792 1361.14,-828"/>
+<text text-anchor="middle" x="1325.64" y="-806.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 5&#45;&gt;41 -->
+<g id="edge31" class="edge">
+<title>5&#45;&gt;41</title>
+<path fill="none" stroke="black" d="M1325.64,-863.7C1325.64,-855.98 1325.64,-846.71 1325.64,-838.11"/>
+<polygon fill="black" stroke="black" points="1329.14,-838.1 1325.64,-828.1 1322.14,-838.1 1329.14,-838.1"/>
+</g>
+<!-- 6 -->
+<g id="node7" class="node">
+<title>6</title>
+<ellipse fill="none" stroke="black" cx="1885.64" cy="-1890" rx="200.36" ry="18"/>
+<text text-anchor="middle" x="1885.64" y="-1886.3" font-family="Times,serif" font-size="14.00">value.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 46 -->
+<g id="node38" class="node">
+<title>46</title>
+<polygon fill="none" stroke="black" points="1982.14,-1836 1789.14,-1836 1789.14,-1800 1982.14,-1800 1982.14,-1836"/>
+<text text-anchor="middle" x="1885.64" y="-1814.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 6&#45;&gt;46 -->
+<g id="edge36" class="edge">
+<title>6&#45;&gt;46</title>
+<path fill="none" stroke="black" d="M1885.64,-1871.7C1885.64,-1863.98 1885.64,-1854.71 1885.64,-1846.11"/>
+<polygon fill="black" stroke="black" points="1889.14,-1846.1 1885.64,-1836.1 1882.14,-1846.1 1889.14,-1846.1"/>
+</g>
+<!-- 7 -->
+<g id="node8" class="node">
+<title>7</title>
+<ellipse fill="none" stroke="black" cx="1581.64" cy="-1530" rx="167.07" ry="18"/>
+<text text-anchor="middle" x="1581.64" y="-1526.3" font-family="Times,serif" font-size="14.00">value.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 51 -->
+<g id="node43" class="node">
+<title>51</title>
+<polygon fill="none" stroke="black" points="1617.14,-1476 1546.14,-1476 1546.14,-1440 1617.14,-1440 1617.14,-1476"/>
+<text text-anchor="middle" x="1581.64" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 7&#45;&gt;51 -->
+<g id="edge43" class="edge">
+<title>7&#45;&gt;51</title>
+<path fill="none" stroke="black" d="M1581.64,-1511.7C1581.64,-1503.98 1581.64,-1494.71 1581.64,-1486.11"/>
+<polygon fill="black" stroke="black" points="1585.14,-1486.1 1581.64,-1476.1 1578.14,-1486.1 1585.14,-1486.1"/>
+</g>
+<!-- 20 -->
+<g id="node13" class="node">
+<title>20</title>
+<polygon fill="none" stroke="black" points="1307.14,-1620 1138.14,-1620 1138.14,-1584 1307.14,-1584 1307.14,-1620"/>
+<text text-anchor="middle" x="1222.64" y="-1598.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 16&#45;&gt;20 -->
+<g id="edge5" class="edge">
+<title>16&#45;&gt;20</title>
+<path fill="none" stroke="black" d="M1390.61,-1655.97C1359.19,-1646.06 1319.41,-1633.51 1286.45,-1623.12"/>
+<polygon fill="black" stroke="black" points="1287.49,-1619.78 1276.9,-1620.11 1285.39,-1626.46 1287.49,-1619.78"/>
+</g>
+<!-- 29 -->
+<g id="node22" class="node">
+<title>29</title>
+<polygon fill="none" stroke="black" points="649.14,-1620 480.14,-1620 480.14,-1584 649.14,-1584 649.14,-1620"/>
+<text text-anchor="middle" x="564.64" y="-1598.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 16&#45;&gt;29 -->
+<g id="edge16" class="edge">
+<title>16&#45;&gt;29</title>
+<path fill="none" stroke="black" d="M1268.41,-1658.98C1085.06,-1644.4 804.87,-1622.11 659.51,-1610.55"/>
+<polygon fill="black" stroke="black" points="659.53,-1607.04 649.28,-1609.73 658.97,-1614.02 659.53,-1607.04"/>
+</g>
+<!-- 49 -->
+<g id="node41" class="node">
+<title>49</title>
+<polygon fill="none" stroke="black" points="1970.14,-1620 1801.14,-1620 1801.14,-1584 1970.14,-1584 1970.14,-1620"/>
+<text text-anchor="middle" x="1885.64" y="-1598.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 16&#45;&gt;49 -->
+<g id="edge39" class="edge">
+<title>16&#45;&gt;49</title>
+<path fill="none" stroke="black" d="M1551.95,-1655.97C1624.98,-1644.37 1720.72,-1629.18 1790.91,-1618.04"/>
+<polygon fill="black" stroke="black" points="1791.75,-1621.45 1801.08,-1616.42 1790.65,-1614.53 1791.75,-1621.45"/>
+</g>
+<!-- 18 -->
+<g id="node11" class="node">
+<title>18</title>
+<polygon fill="none" stroke="black" points="1273.14,-1764 912.14,-1764 912.14,-1728 1273.14,-1728 1273.14,-1764"/>
+<text text-anchor="middle" x="1092.64" y="-1742.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 17&#45;&gt;18 -->
+<g id="edge3" class="edge">
+<title>17&#45;&gt;18</title>
+<path fill="none" stroke="black" d="M1092.64,-1799.7C1092.64,-1791.98 1092.64,-1782.71 1092.64,-1774.11"/>
+<polygon fill="black" stroke="black" points="1096.14,-1774.1 1092.64,-1764.1 1089.14,-1774.1 1096.14,-1774.1"/>
+</g>
+<!-- 19 -->
+<g id="node12" class="node">
+<title>19</title>
+<polygon fill="none" stroke="black" points="1216.14,-1692 1005.14,-1692 1005.14,-1656 1216.14,-1656 1216.14,-1692"/>
+<text text-anchor="middle" x="1110.64" y="-1670.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 18&#45;&gt;19 -->
+<g id="edge4" class="edge">
+<title>18&#45;&gt;19</title>
+<path fill="none" stroke="black" d="M1097.09,-1727.7C1099.09,-1719.9 1101.51,-1710.51 1103.74,-1701.83"/>
+<polygon fill="black" stroke="black" points="1107.14,-1702.66 1106.24,-1692.1 1100.36,-1700.92 1107.14,-1702.66"/>
+</g>
+<!-- 19&#45;&gt;20 -->
+<g id="edge6" class="edge">
+<title>19&#45;&gt;20</title>
+<path fill="none" stroke="black" d="M1138.03,-1655.88C1152.54,-1646.81 1170.55,-1635.55 1186.22,-1625.76"/>
+<polygon fill="black" stroke="black" points="1188.51,-1628.46 1195.13,-1620.19 1184.8,-1622.52 1188.51,-1628.46"/>
+</g>
+<!-- 21 -->
+<g id="node14" class="node">
+<title>21</title>
+<polygon fill="none" stroke="black" points="1396.14,-1548 1049.14,-1548 1049.14,-1512 1396.14,-1512 1396.14,-1548"/>
+<text text-anchor="middle" x="1222.64" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 20&#45;&gt;21 -->
+<g id="edge7" class="edge">
+<title>20&#45;&gt;21</title>
+<path fill="none" stroke="black" d="M1222.64,-1583.7C1222.64,-1575.98 1222.64,-1566.71 1222.64,-1558.11"/>
+<polygon fill="black" stroke="black" points="1226.14,-1558.1 1222.64,-1548.1 1219.14,-1558.1 1226.14,-1558.1"/>
+</g>
+<!-- 21&#45;&gt;22 -->
+<g id="edge8" class="edge">
+<title>21&#45;&gt;22</title>
+<path fill="none" stroke="black" d="M1160.34,-1511.97C1114.53,-1499.44 1053.33,-1482.71 1012.08,-1471.43"/>
+<polygon fill="black" stroke="black" points="1012.87,-1468.01 1002.3,-1468.75 1011.03,-1474.77 1012.87,-1468.01"/>
+</g>
+<!-- 23 -->
+<g id="node16" class="node">
+<title>23</title>
+<polygon fill="none" stroke="black" points="1149.14,-1404 784.14,-1404 784.14,-1368 1149.14,-1368 1149.14,-1404"/>
+<text text-anchor="middle" x="966.64" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 22&#45;&gt;23 -->
+<g id="edge10" class="edge">
+<title>22&#45;&gt;23</title>
+<path fill="none" stroke="black" d="M966.64,-1439.7C966.64,-1431.98 966.64,-1422.71 966.64,-1414.11"/>
+<polygon fill="black" stroke="black" points="970.14,-1414.1 966.64,-1404.1 963.14,-1414.1 970.14,-1414.1"/>
+</g>
+<!-- 24 -->
+<g id="node17" class="node">
+<title>24</title>
+<polygon fill="none" stroke="black" points="1081.14,-1332 852.14,-1332 852.14,-1296 1081.14,-1296 1081.14,-1332"/>
+<text text-anchor="middle" x="966.64" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 23&#45;&gt;24 -->
+<g id="edge11" class="edge">
+<title>23&#45;&gt;24</title>
+<path fill="none" stroke="black" d="M966.64,-1367.7C966.64,-1359.98 966.64,-1350.71 966.64,-1342.11"/>
+<polygon fill="black" stroke="black" points="970.14,-1342.1 966.64,-1332.1 963.14,-1342.1 970.14,-1342.1"/>
+</g>
+<!-- 25 -->
+<g id="node18" class="node">
+<title>25</title>
+<polygon fill="none" stroke="black" points="1138.14,-1116 795.14,-1116 795.14,-1080 1138.14,-1080 1138.14,-1116"/>
+<text text-anchor="middle" x="966.64" y="-1094.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 64], reverse=0)</text>
+</g>
+<!-- 24&#45;&gt;25 -->
+<g id="edge12" class="edge">
+<title>24&#45;&gt;25</title>
+<path fill="none" stroke="black" d="M966.64,-1295.85C966.64,-1258.83 966.64,-1171.18 966.64,-1126.39"/>
+<polygon fill="black" stroke="black" points="970.14,-1126.23 966.64,-1116.23 963.14,-1126.23 970.14,-1126.23"/>
+</g>
+<!-- 37 -->
+<g id="node30" class="node">
+<title>37</title>
+<polygon fill="none" stroke="black" points="1051.14,-1044 882.14,-1044 882.14,-1008 1051.14,-1008 1051.14,-1044"/>
+<text text-anchor="middle" x="966.64" y="-1022.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 25&#45;&gt;37 -->
+<g id="edge26" class="edge">
+<title>25&#45;&gt;37</title>
+<path fill="none" stroke="black" d="M966.64,-1079.7C966.64,-1071.98 966.64,-1062.71 966.64,-1054.11"/>
+<polygon fill="black" stroke="black" points="970.14,-1054.1 966.64,-1044.1 963.14,-1054.1 970.14,-1054.1"/>
+</g>
+<!-- 27 -->
+<g id="node20" class="node">
+<title>27</title>
+<polygon fill="none" stroke="black" points="745.14,-1764 384.14,-1764 384.14,-1728 745.14,-1728 745.14,-1764"/>
+<text text-anchor="middle" x="564.64" y="-1742.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 26&#45;&gt;27 -->
+<g id="edge14" class="edge">
+<title>26&#45;&gt;27</title>
+<path fill="none" stroke="black" d="M564.64,-1799.7C564.64,-1791.98 564.64,-1782.71 564.64,-1774.11"/>
+<polygon fill="black" stroke="black" points="568.14,-1774.1 564.64,-1764.1 561.14,-1774.1 568.14,-1774.1"/>
+</g>
+<!-- 28 -->
+<g id="node21" class="node">
+<title>28</title>
+<polygon fill="none" stroke="black" points="670.14,-1692 459.14,-1692 459.14,-1656 670.14,-1656 670.14,-1692"/>
+<text text-anchor="middle" x="564.64" y="-1670.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 27&#45;&gt;28 -->
+<g id="edge15" class="edge">
+<title>27&#45;&gt;28</title>
+<path fill="none" stroke="black" d="M564.64,-1727.7C564.64,-1719.98 564.64,-1710.71 564.64,-1702.11"/>
+<polygon fill="black" stroke="black" points="568.14,-1702.1 564.64,-1692.1 561.14,-1702.1 568.14,-1702.1"/>
+</g>
+<!-- 28&#45;&gt;29 -->
+<g id="edge17" class="edge">
+<title>28&#45;&gt;29</title>
+<path fill="none" stroke="black" d="M564.64,-1655.7C564.64,-1647.98 564.64,-1638.71 564.64,-1630.11"/>
+<polygon fill="black" stroke="black" points="568.14,-1630.1 564.64,-1620.1 561.14,-1630.1 568.14,-1630.1"/>
+</g>
+<!-- 30 -->
+<g id="node23" class="node">
+<title>30</title>
+<polygon fill="none" stroke="black" points="678.14,-1548 331.14,-1548 331.14,-1512 678.14,-1512 678.14,-1548"/>
+<text text-anchor="middle" x="504.64" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 29&#45;&gt;30 -->
+<g id="edge18" class="edge">
+<title>29&#45;&gt;30</title>
+<path fill="none" stroke="black" d="M549.8,-1583.7C542.61,-1575.3 533.84,-1565.07 525.95,-1555.86"/>
+<polygon fill="black" stroke="black" points="528.46,-1553.42 519.3,-1548.1 523.15,-1557.97 528.46,-1553.42"/>
+</g>
+<!-- 30&#45;&gt;31 -->
+<g id="edge19" class="edge">
+<title>30&#45;&gt;31</title>
+<path fill="none" stroke="black" d="M504.64,-1511.7C504.64,-1503.98 504.64,-1494.71 504.64,-1486.11"/>
+<polygon fill="black" stroke="black" points="508.14,-1486.1 504.64,-1476.1 501.14,-1486.1 508.14,-1486.1"/>
+</g>
+<!-- 32 -->
+<g id="node25" class="node">
+<title>32</title>
+<polygon fill="none" stroke="black" points="766.14,-1404 401.14,-1404 401.14,-1368 766.14,-1368 766.14,-1404"/>
+<text text-anchor="middle" x="583.64" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 31&#45;&gt;32 -->
+<g id="edge21" class="edge">
+<title>31&#45;&gt;32</title>
+<path fill="none" stroke="black" d="M524.16,-1439.7C534.02,-1430.97 546.12,-1420.24 556.83,-1410.75"/>
+<polygon fill="black" stroke="black" points="559.17,-1413.36 564.33,-1404.1 554.53,-1408.12 559.17,-1413.36"/>
+</g>
+<!-- 33 -->
+<g id="node26" class="node">
+<title>33</title>
+<polygon fill="none" stroke="black" points="704.14,-1332 475.14,-1332 475.14,-1296 704.14,-1296 704.14,-1332"/>
+<text text-anchor="middle" x="589.64" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 32&#45;&gt;33 -->
+<g id="edge22" class="edge">
+<title>32&#45;&gt;33</title>
+<path fill="none" stroke="black" d="M585.12,-1367.7C585.78,-1359.98 586.58,-1350.71 587.31,-1342.11"/>
+<polygon fill="black" stroke="black" points="590.8,-1342.37 588.17,-1332.1 583.83,-1341.77 590.8,-1342.37"/>
+</g>
+<!-- 34 -->
+<g id="node27" class="node">
+<title>34</title>
+<polygon fill="none" stroke="black" points="748.14,-1260 519.14,-1260 519.14,-1224 748.14,-1224 748.14,-1260"/>
+<text text-anchor="middle" x="633.64" y="-1238.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 1, 3, 2])</text>
+</g>
+<!-- 33&#45;&gt;34 -->
+<g id="edge23" class="edge">
+<title>33&#45;&gt;34</title>
+<path fill="none" stroke="black" d="M600.51,-1295.7C605.63,-1287.56 611.83,-1277.69 617.48,-1268.7"/>
+<polygon fill="black" stroke="black" points="620.53,-1270.43 622.88,-1260.1 614.6,-1266.71 620.53,-1270.43"/>
+</g>
+<!-- 35 -->
+<g id="node28" class="node">
+<title>35</title>
+<polygon fill="none" stroke="black" points="827.14,-1188 484.14,-1188 484.14,-1152 827.14,-1152 827.14,-1188"/>
+<text text-anchor="middle" x="655.64" y="-1166.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 64, 14], reverse=0)</text>
+</g>
+<!-- 34&#45;&gt;35 -->
+<g id="edge24" class="edge">
+<title>34&#45;&gt;35</title>
+<path fill="none" stroke="black" d="M639.07,-1223.7C641.53,-1215.9 644.48,-1206.51 647.2,-1197.83"/>
+<polygon fill="black" stroke="black" points="650.6,-1198.69 650.26,-1188.1 643.92,-1196.59 650.6,-1198.69"/>
+</g>
+<!-- 36 -->
+<g id="node29" class="node">
+<title>36</title>
+<polygon fill="none" stroke="black" points="772.14,-1116 561.14,-1116 561.14,-1080 772.14,-1080 772.14,-1116"/>
+<text text-anchor="middle" x="666.64" y="-1094.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 35&#45;&gt;36 -->
+<g id="edge25" class="edge">
+<title>35&#45;&gt;36</title>
+<path fill="none" stroke="black" d="M658.36,-1151.7C659.57,-1143.98 661.02,-1134.71 662.38,-1126.11"/>
+<polygon fill="black" stroke="black" points="665.85,-1126.53 663.95,-1116.1 658.94,-1125.44 665.85,-1126.53"/>
+</g>
+<!-- 36&#45;&gt;37 -->
+<g id="edge27" class="edge">
+<title>36&#45;&gt;37</title>
+<path fill="none" stroke="black" d="M739.64,-1079.97C783.19,-1069.8 838.61,-1056.87 883.79,-1046.33"/>
+<polygon fill="black" stroke="black" points="884.62,-1049.73 893.56,-1044.05 883.03,-1042.91 884.62,-1049.73"/>
+</g>
+<!-- 38 -->
+<g id="node31" class="node">
+<title>38</title>
+<polygon fill="none" stroke="black" points="1189.14,-972 824.14,-972 824.14,-936 1189.14,-936 1189.14,-972"/>
+<text text-anchor="middle" x="1006.64" y="-950.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 12, 14, 14], reverse=0)</text>
+</g>
+<!-- 37&#45;&gt;38 -->
+<g id="edge28" class="edge">
+<title>37&#45;&gt;38</title>
+<path fill="none" stroke="black" d="M976.52,-1007.7C981.13,-999.64 986.7,-989.89 991.79,-980.98"/>
+<polygon fill="black" stroke="black" points="994.94,-982.52 996.86,-972.1 988.86,-979.05 994.94,-982.52"/>
+</g>
+<!-- 40 -->
+<g id="node32" class="node">
+<title>40</title>
+<polygon fill="none" stroke="black" points="1080.14,-900 973.14,-900 973.14,-864 1080.14,-864 1080.14,-900"/>
+<text text-anchor="middle" x="1026.64" y="-878.3" font-family="Times,serif" font-size="14.00">divide(·, 8.0)</text>
+</g>
+<!-- 38&#45;&gt;40 -->
+<g id="edge29" class="edge">
+<title>38&#45;&gt;40</title>
+<path fill="none" stroke="black" d="M1011.58,-935.7C1013.81,-927.9 1016.49,-918.51 1018.97,-909.83"/>
+<polygon fill="black" stroke="black" points="1022.37,-910.68 1021.75,-900.1 1015.64,-908.76 1022.37,-910.68"/>
+</g>
+<!-- 40&#45;&gt;41 -->
+<g id="edge30" class="edge">
+<title>40&#45;&gt;41</title>
+<path fill="none" stroke="black" d="M1080.5,-868.27C1086.61,-866.83 1092.76,-865.38 1098.64,-864 1161.45,-849.25 1234.08,-832.31 1279.92,-821.64"/>
+<polygon fill="black" stroke="black" points="1280.91,-825 1289.85,-819.33 1279.32,-818.18 1280.91,-825"/>
+</g>
+<!-- 42 -->
+<g id="node34" class="node">
+<title>42</title>
+<polygon fill="none" stroke="black" points="1420.64,-756 1246.64,-756 1246.64,-720 1420.64,-720 1420.64,-756"/>
+<text text-anchor="middle" x="1333.64" y="-734.3" font-family="Times,serif" font-size="14.00">nn.softmax(·| axis=&#45;1)</text>
+</g>
+<!-- 41&#45;&gt;42 -->
+<g id="edge32" class="edge">
+<title>41&#45;&gt;42</title>
+<path fill="none" stroke="black" d="M1327.61,-791.7C1328.5,-783.98 1329.55,-774.71 1330.54,-766.11"/>
+<polygon fill="black" stroke="black" points="1334.02,-766.44 1331.68,-756.1 1327.07,-765.64 1334.02,-766.44"/>
+</g>
+<!-- 43 -->
+<g id="node35" class="node">
+<title>43</title>
+<polygon fill="none" stroke="black" points="1448.14,-684 1265.14,-684 1265.14,-648 1448.14,-648 1448.14,-684"/>
+<text text-anchor="middle" x="1356.64" y="-662.3" font-family="Times,serif" font-size="14.00">nn.dropout(·| rate=0.1)</text>
+</g>
+<!-- 42&#45;&gt;43 -->
+<g id="edge33" class="edge">
+<title>42&#45;&gt;43</title>
+<path fill="none" stroke="black" d="M1339.32,-719.7C1341.88,-711.9 1344.97,-702.51 1347.82,-693.83"/>
+<polygon fill="black" stroke="black" points="1351.22,-694.7 1351.02,-684.1 1344.57,-692.51 1351.22,-694.7"/>
+</g>
+<!-- 44 -->
+<g id="node36" class="node">
+<title>44</title>
+<polygon fill="none" stroke="black" points="1452.64,-612 1284.64,-612 1284.64,-576 1452.64,-576 1452.64,-612"/>
+<text text-anchor="middle" x="1368.64" y="-590.3" font-family="Times,serif" font-size="14.00">TupleGetItem(idx=0)</text>
+</g>
+<!-- 43&#45;&gt;44 -->
+<g id="edge34" class="edge">
+<title>43&#45;&gt;44</title>
+<path fill="none" stroke="black" d="M1359.6,-647.7C1360.92,-639.98 1362.51,-630.71 1363.99,-622.11"/>
+<polygon fill="black" stroke="black" points="1367.46,-622.55 1365.7,-612.1 1360.56,-621.37 1367.46,-622.55"/>
+</g>
+<!-- 45 -->
+<g id="node37" class="node">
+<title>45</title>
+<polygon fill="none" stroke="black" points="1546.14,-540 1203.14,-540 1203.14,-504 1546.14,-504 1546.14,-540"/>
+<text text-anchor="middle" x="1374.64" y="-518.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 14], reverse=0)</text>
+</g>
+<!-- 44&#45;&gt;45 -->
+<g id="edge35" class="edge">
+<title>44&#45;&gt;45</title>
+<path fill="none" stroke="black" d="M1370.12,-575.7C1370.78,-567.98 1371.58,-558.71 1372.31,-550.11"/>
+<polygon fill="black" stroke="black" points="1375.8,-550.37 1373.17,-540.1 1368.83,-549.77 1375.8,-550.37"/>
+</g>
+<!-- 56 -->
+<g id="node48" class="node">
+<title>56</title>
+<polygon fill="none" stroke="black" points="1558.14,-468 1389.14,-468 1389.14,-432 1558.14,-432 1558.14,-468"/>
+<text text-anchor="middle" x="1473.64" y="-446.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 45&#45;&gt;56 -->
+<g id="edge48" class="edge">
+<title>45&#45;&gt;56</title>
+<path fill="none" stroke="black" d="M1398.85,-503.88C1411.56,-494.89 1427.3,-483.76 1441.06,-474.03"/>
+<polygon fill="black" stroke="black" points="1443.18,-476.82 1449.32,-468.19 1439.14,-471.11 1443.18,-476.82"/>
+</g>
+<!-- 47 -->
+<g id="node39" class="node">
+<title>47</title>
+<polygon fill="none" stroke="black" points="2066.14,-1764 1705.14,-1764 1705.14,-1728 2066.14,-1728 2066.14,-1764"/>
+<text text-anchor="middle" x="1885.64" y="-1742.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 46&#45;&gt;47 -->
+<g id="edge37" class="edge">
+<title>46&#45;&gt;47</title>
+<path fill="none" stroke="black" d="M1885.64,-1799.7C1885.64,-1791.98 1885.64,-1782.71 1885.64,-1774.11"/>
+<polygon fill="black" stroke="black" points="1889.14,-1774.1 1885.64,-1764.1 1882.14,-1774.1 1889.14,-1774.1"/>
+</g>
+<!-- 48 -->
+<g id="node40" class="node">
+<title>48</title>
+<polygon fill="none" stroke="black" points="1991.14,-1692 1780.14,-1692 1780.14,-1656 1991.14,-1656 1991.14,-1692"/>
+<text text-anchor="middle" x="1885.64" y="-1670.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 47&#45;&gt;48 -->
+<g id="edge38" class="edge">
+<title>47&#45;&gt;48</title>
+<path fill="none" stroke="black" d="M1885.64,-1727.7C1885.64,-1719.98 1885.64,-1710.71 1885.64,-1702.11"/>
+<polygon fill="black" stroke="black" points="1889.14,-1702.1 1885.64,-1692.1 1882.14,-1702.1 1889.14,-1702.1"/>
+</g>
+<!-- 48&#45;&gt;49 -->
+<g id="edge40" class="edge">
+<title>48&#45;&gt;49</title>
+<path fill="none" stroke="black" d="M1885.64,-1655.7C1885.64,-1647.98 1885.64,-1638.71 1885.64,-1630.11"/>
+<polygon fill="black" stroke="black" points="1889.14,-1630.1 1885.64,-1620.1 1882.14,-1630.1 1889.14,-1630.1"/>
+</g>
+<!-- 50 -->
+<g id="node42" class="node">
+<title>50</title>
+<polygon fill="none" stroke="black" points="2114.14,-1548 1767.14,-1548 1767.14,-1512 2114.14,-1512 2114.14,-1548"/>
+<text text-anchor="middle" x="1940.64" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 49&#45;&gt;50 -->
+<g id="edge41" class="edge">
+<title>49&#45;&gt;50</title>
+<path fill="none" stroke="black" d="M1899.23,-1583.7C1905.76,-1575.39 1913.7,-1565.28 1920.88,-1556.14"/>
+<polygon fill="black" stroke="black" points="1923.77,-1558.13 1927.2,-1548.1 1918.27,-1553.81 1923.77,-1558.13"/>
+</g>
+<!-- 50&#45;&gt;51 -->
+<g id="edge42" class="edge">
+<title>50&#45;&gt;51</title>
+<path fill="none" stroke="black" d="M1853.27,-1511.97C1781.96,-1498.06 1684.02,-1478.96 1627.31,-1467.91"/>
+<polygon fill="black" stroke="black" points="1627.7,-1464.42 1617.22,-1465.94 1626.36,-1471.29 1627.7,-1464.42"/>
+</g>
+<!-- 52 -->
+<g id="node44" class="node">
+<title>52</title>
+<polygon fill="none" stroke="black" points="1764.14,-1404 1399.14,-1404 1399.14,-1368 1764.14,-1368 1764.14,-1404"/>
+<text text-anchor="middle" x="1581.64" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 51&#45;&gt;52 -->
+<g id="edge44" class="edge">
+<title>51&#45;&gt;52</title>
+<path fill="none" stroke="black" d="M1581.64,-1439.7C1581.64,-1431.98 1581.64,-1422.71 1581.64,-1414.11"/>
+<polygon fill="black" stroke="black" points="1585.14,-1414.1 1581.64,-1404.1 1578.14,-1414.1 1585.14,-1414.1"/>
+</g>
+<!-- 53 -->
+<g id="node45" class="node">
+<title>53</title>
+<polygon fill="none" stroke="black" points="1695.14,-1332 1466.14,-1332 1466.14,-1296 1695.14,-1296 1695.14,-1332"/>
+<text text-anchor="middle" x="1580.64" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 52&#45;&gt;53 -->
+<g id="edge45" class="edge">
+<title>52&#45;&gt;53</title>
+<path fill="none" stroke="black" d="M1581.39,-1367.7C1581.28,-1359.98 1581.15,-1350.71 1581.02,-1342.11"/>
+<polygon fill="black" stroke="black" points="1584.52,-1342.05 1580.88,-1332.1 1577.52,-1342.15 1584.52,-1342.05"/>
+</g>
+<!-- 54 -->
+<g id="node46" class="node">
+<title>54</title>
+<polygon fill="none" stroke="black" points="1751.14,-1260 1408.14,-1260 1408.14,-1224 1751.14,-1224 1751.14,-1260"/>
+<text text-anchor="middle" x="1579.64" y="-1238.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 64], reverse=0)</text>
+</g>
+<!-- 53&#45;&gt;54 -->
+<g id="edge46" class="edge">
+<title>53&#45;&gt;54</title>
+<path fill="none" stroke="black" d="M1580.39,-1295.7C1580.28,-1287.98 1580.15,-1278.71 1580.02,-1270.11"/>
+<polygon fill="black" stroke="black" points="1583.52,-1270.05 1579.88,-1260.1 1576.52,-1270.15 1583.52,-1270.05"/>
+</g>
+<!-- 55 -->
+<g id="node47" class="node">
+<title>55</title>
+<polygon fill="none" stroke="black" points="1682.14,-1044 1471.14,-1044 1471.14,-1008 1682.14,-1008 1682.14,-1044"/>
+<text text-anchor="middle" x="1576.64" y="-1022.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 54&#45;&gt;55 -->
+<g id="edge47" class="edge">
+<title>54&#45;&gt;55</title>
+<path fill="none" stroke="black" d="M1579.4,-1223.85C1578.88,-1186.83 1577.65,-1099.18 1577.02,-1054.39"/>
+<polygon fill="black" stroke="black" points="1580.52,-1054.18 1576.88,-1044.23 1573.52,-1054.28 1580.52,-1054.18"/>
+</g>
+<!-- 55&#45;&gt;56 -->
+<g id="edge49" class="edge">
+<title>55&#45;&gt;56</title>
+<path fill="none" stroke="black" d="M1576.64,-1007.95C1576.64,-981.29 1576.64,-928.11 1576.64,-883 1576.64,-883 1576.64,-883 1576.64,-593 1576.64,-552.36 1579.06,-537.21 1555.64,-504 1546.9,-491.62 1534.27,-481.42 1521.53,-473.4"/>
+<polygon fill="black" stroke="black" points="1523.13,-470.27 1512.74,-468.19 1519.56,-476.3 1523.13,-470.27"/>
+</g>
+<!-- 57 -->
+<g id="node49" class="node">
+<title>57</title>
+<polygon fill="none" stroke="black" points="1656.14,-396 1291.14,-396 1291.14,-360 1656.14,-360 1656.14,-396"/>
+<text text-anchor="middle" x="1473.64" y="-374.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 12, 14, 64], reverse=0)</text>
+</g>
+<!-- 56&#45;&gt;57 -->
+<g id="edge50" class="edge">
+<title>56&#45;&gt;57</title>
+<path fill="none" stroke="black" d="M1473.64,-431.7C1473.64,-423.98 1473.64,-414.71 1473.64,-406.11"/>
+<polygon fill="black" stroke="black" points="1477.14,-406.1 1473.64,-396.1 1470.14,-406.1 1477.14,-406.1"/>
+</g>
+<!-- 58 -->
+<g id="node50" class="node">
+<title>58</title>
+<polygon fill="none" stroke="black" points="1588.14,-324 1359.14,-324 1359.14,-288 1588.14,-288 1588.14,-324"/>
+<text text-anchor="middle" x="1473.64" y="-302.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 57&#45;&gt;58 -->
+<g id="edge51" class="edge">
+<title>57&#45;&gt;58</title>
+<path fill="none" stroke="black" d="M1473.64,-359.7C1473.64,-351.98 1473.64,-342.71 1473.64,-334.11"/>
+<polygon fill="black" stroke="black" points="1477.14,-334.1 1473.64,-324.1 1470.14,-334.1 1477.14,-334.1"/>
+</g>
+<!-- 59 -->
+<g id="node51" class="node">
+<title>59</title>
+<polygon fill="none" stroke="black" points="1506.14,-252 1441.14,-252 1441.14,-216 1506.14,-216 1506.14,-252"/>
+<text text-anchor="middle" x="1473.64" y="-230.3" font-family="Times,serif" font-size="14.00">copy(·)</text>
+</g>
+<!-- 58&#45;&gt;59 -->
+<g id="edge52" class="edge">
+<title>58&#45;&gt;59</title>
+<path fill="none" stroke="black" d="M1473.64,-287.7C1473.64,-279.98 1473.64,-270.71 1473.64,-262.11"/>
+<polygon fill="black" stroke="black" points="1477.14,-262.1 1473.64,-252.1 1470.14,-262.1 1477.14,-262.1"/>
+</g>
+<!-- 60 -->
+<g id="node52" class="node">
+<title>60</title>
+<polygon fill="none" stroke="black" points="1647.14,-180 1300.14,-180 1300.14,-144 1647.14,-144 1647.14,-180"/>
+<text text-anchor="middle" x="1473.64" y="-158.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 59&#45;&gt;60 -->
+<g id="edge53" class="edge">
+<title>59&#45;&gt;60</title>
+<path fill="none" stroke="black" d="M1473.64,-215.7C1473.64,-207.98 1473.64,-198.71 1473.64,-190.11"/>
+<polygon fill="black" stroke="black" points="1477.14,-190.1 1473.64,-180.1 1470.14,-190.1 1477.14,-190.1"/>
+</g>
+<!-- 61 -->
+<g id="node53" class="node">
+<title>61</title>
+<polygon fill="none" stroke="black" points="1516.64,-108 1430.64,-108 1430.64,-72 1516.64,-72 1516.64,-108"/>
+<text text-anchor="middle" x="1473.64" y="-86.3" font-family="Times,serif" font-size="14.00">Tuple[...])</text>
+</g>
+<!-- 60&#45;&gt;61 -->
+<g id="edge54" class="edge">
+<title>60&#45;&gt;61</title>
+<path fill="none" stroke="black" d="M1473.64,-143.7C1473.64,-135.98 1473.64,-126.71 1473.64,-118.11"/>
+<polygon fill="black" stroke="black" points="1477.14,-118.1 1473.64,-108.1 1470.14,-118.1 1477.14,-118.1"/>
+</g>
+<!-- 62 -->
+<g id="node54" class="node">
+<title>62</title>
+<polygon fill="none" stroke="black" points="1513.64,-36 1433.64,-36 1433.64,0 1513.64,0 1513.64,-36"/>
+<text text-anchor="middle" x="1473.64" y="-14.3" font-family="Times,serif" font-size="14.00">Function</text>
+</g>
+<!-- 61&#45;&gt;62 -->
+<g id="edge55" class="edge">
+<title>61&#45;&gt;62</title>
+<path fill="none" stroke="black" d="M1473.64,-71.7C1473.64,-63.98 1473.64,-54.71 1473.64,-46.11"/>
+<polygon fill="black" stroke="black" points="1477.14,-46.1 1473.64,-36.1 1470.14,-46.1 1477.14,-46.1"/>
+</g>
+</g>
+</svg>
diff --git a/images/bert-pytorch/bert-tvm_70_0.svg b/images/bert-pytorch/bert-tvm_70_0.svg
new file mode 100644
index 0000000..f015c0b
--- /dev/null
+++ b/images/bert-pytorch/bert-tvm_70_0.svg
@@ -0,0 +1,667 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
+ "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
+<!-- Generated by graphviz version 2.43.0 (0)
+ -->
+<!-- Title: %3 Pages: 1 -->
+<svg width="1718pt" height="1916pt"
+ viewBox="0.00 0.00 1717.50 1916.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
+<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 1912)">
+<title>%3</title>
+<polygon fill="white" stroke="transparent" points="-4,4 -4,-1912 1713.5,-1912 1713.5,4 -4,4"/>
+<!-- 0 -->
+<g id="node1" class="node">
+<title>0</title>
+<ellipse fill="none" stroke="black" cx="947.5" cy="-1746" rx="170.87" ry="18"/>
+<text text-anchor="middle" x="947.5" y="-1742.3" font-family="Times,serif" font-size="14.00">input: Tensor[(1, 14, 768), float32]</text>
+</g>
+<!-- 10 -->
+<g id="node3" class="node">
+<title>10</title>
+<polygon fill="none" stroke="black" points="1089.5,-1692 737.5,-1692 737.5,-1656 1089.5,-1656 1089.5,-1692"/>
+<text text-anchor="middle" x="913.5" y="-1670.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 0&#45;&gt;10 -->
+<g id="edge1" class="edge">
+<title>0&#45;&gt;10</title>
+<path fill="none" stroke="black" d="M939.1,-1727.7C935.23,-1719.73 930.55,-1710.1 926.26,-1701.26"/>
+<polygon fill="black" stroke="black" points="929.33,-1699.57 921.81,-1692.1 923.03,-1702.63 929.33,-1699.57"/>
+</g>
+<!-- 1 -->
+<g id="node2" class="node">
+<title>1</title>
+<ellipse fill="none" stroke="black" cx="1084.5" cy="-882" rx="217.96" ry="18"/>
+<text text-anchor="middle" x="1084.5" y="-878.3" font-family="Times,serif" font-size="14.00">attention_mask: Tensor[(1, 1, 1, 14), float32]</text>
+</g>
+<!-- 39 -->
+<g id="node31" class="node">
+<title>39</title>
+<polygon fill="none" stroke="black" points="1120,-828 1049,-828 1049,-792 1120,-792 1120,-828"/>
+<text text-anchor="middle" x="1084.5" y="-806.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 1&#45;&gt;39 -->
+<g id="edge31" class="edge">
+<title>1&#45;&gt;39</title>
+<path fill="none" stroke="black" d="M1084.5,-863.7C1084.5,-855.98 1084.5,-846.71 1084.5,-838.11"/>
+<polygon fill="black" stroke="black" points="1088,-838.1 1084.5,-828.1 1081,-838.1 1088,-838.1"/>
+</g>
+<!-- 15 -->
+<g id="node8" class="node">
+<title>15</title>
+<polygon fill="none" stroke="black" points="834,-1620 665,-1620 665,-1584 834,-1584 834,-1620"/>
+<text text-anchor="middle" x="749.5" y="-1598.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 10&#45;&gt;15 -->
+<g id="edge5" class="edge">
+<title>10&#45;&gt;15</title>
+<path fill="none" stroke="black" d="M873.38,-1655.88C850.86,-1646.26 822.58,-1634.19 798.73,-1624.01"/>
+<polygon fill="black" stroke="black" points="800.06,-1620.77 789.49,-1620.07 797.31,-1627.21 800.06,-1620.77"/>
+</g>
+<!-- 26 -->
+<g id="node19" class="node">
+<title>26</title>
+<polygon fill="none" stroke="black" points="275,-1620 106,-1620 106,-1584 275,-1584 275,-1620"/>
+<text text-anchor="middle" x="190.5" y="-1598.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 10&#45;&gt;26 -->
+<g id="edge16" class="edge">
+<title>10&#45;&gt;26</title>
+<path fill="none" stroke="black" d="M737.56,-1655.97C595.14,-1642.18 400.01,-1623.28 285.39,-1612.19"/>
+<polygon fill="black" stroke="black" points="285.53,-1608.68 275.24,-1611.2 284.86,-1615.65 285.53,-1608.68"/>
+</g>
+<!-- 48 -->
+<g id="node40" class="node">
+<title>48</title>
+<polygon fill="none" stroke="black" points="1405,-1620 1236,-1620 1236,-1584 1405,-1584 1405,-1620"/>
+<text text-anchor="middle" x="1320.5" y="-1598.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 10&#45;&gt;48 -->
+<g id="edge39" class="edge">
+<title>10&#45;&gt;48</title>
+<path fill="none" stroke="black" d="M1012.54,-1655.97C1077.43,-1644.81 1161.75,-1630.3 1225.75,-1619.3"/>
+<polygon fill="black" stroke="black" points="1226.58,-1622.71 1235.84,-1617.56 1225.39,-1615.81 1226.58,-1622.71"/>
+</g>
+<!-- 11 -->
+<g id="node4" class="node">
+<title>11</title>
+<polygon fill="none" stroke="black" points="690.5,-1908 466.5,-1908 466.5,-1872 690.5,-1872 690.5,-1908"/>
+<text text-anchor="middle" x="578.5" y="-1886.3" font-family="Times,serif" font-size="14.00">Constant((768, 768), float32)</text>
+</g>
+<!-- 12 -->
+<g id="node5" class="node">
+<title>12</title>
+<polygon fill="none" stroke="black" points="675,-1836 482,-1836 482,-1800 675,-1800 675,-1836"/>
+<text text-anchor="middle" x="578.5" y="-1814.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 11&#45;&gt;12 -->
+<g id="edge2" class="edge">
+<title>11&#45;&gt;12</title>
+<path fill="none" stroke="black" d="M578.5,-1871.7C578.5,-1863.98 578.5,-1854.71 578.5,-1846.11"/>
+<polygon fill="black" stroke="black" points="582,-1846.1 578.5,-1836.1 575,-1846.1 582,-1846.1"/>
+</g>
+<!-- 13 -->
+<g id="node6" class="node">
+<title>13</title>
+<polygon fill="none" stroke="black" points="759,-1764 398,-1764 398,-1728 759,-1728 759,-1764"/>
+<text text-anchor="middle" x="578.5" y="-1742.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 12&#45;&gt;13 -->
+<g id="edge3" class="edge">
+<title>12&#45;&gt;13</title>
+<path fill="none" stroke="black" d="M578.5,-1799.7C578.5,-1791.98 578.5,-1782.71 578.5,-1774.11"/>
+<polygon fill="black" stroke="black" points="582,-1774.1 578.5,-1764.1 575,-1774.1 582,-1774.1"/>
+</g>
+<!-- 14 -->
+<g id="node7" class="node">
+<title>14</title>
+<polygon fill="none" stroke="black" points="701,-1692 490,-1692 490,-1656 701,-1656 701,-1692"/>
+<text text-anchor="middle" x="595.5" y="-1670.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 13&#45;&gt;14 -->
+<g id="edge4" class="edge">
+<title>13&#45;&gt;14</title>
+<path fill="none" stroke="black" d="M582.7,-1727.7C584.6,-1719.9 586.88,-1710.51 588.98,-1701.83"/>
+<polygon fill="black" stroke="black" points="592.39,-1702.65 591.35,-1692.1 585.58,-1701 592.39,-1702.65"/>
+</g>
+<!-- 14&#45;&gt;15 -->
+<g id="edge6" class="edge">
+<title>14&#45;&gt;15</title>
+<path fill="none" stroke="black" d="M633.17,-1655.88C654.13,-1646.35 680.4,-1634.41 702.68,-1624.28"/>
+<polygon fill="black" stroke="black" points="704.29,-1627.39 711.95,-1620.07 701.4,-1621.02 704.29,-1627.39"/>
+</g>
+<!-- 16 -->
+<g id="node9" class="node">
+<title>16</title>
+<polygon fill="none" stroke="black" points="923,-1548 576,-1548 576,-1512 923,-1512 923,-1548"/>
+<text text-anchor="middle" x="749.5" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 15&#45;&gt;16 -->
+<g id="edge7" class="edge">
+<title>15&#45;&gt;16</title>
+<path fill="none" stroke="black" d="M749.5,-1583.7C749.5,-1575.98 749.5,-1566.71 749.5,-1558.11"/>
+<polygon fill="black" stroke="black" points="753,-1558.1 749.5,-1548.1 746,-1558.1 753,-1558.1"/>
+</g>
+<!-- 18 -->
+<g id="node11" class="node">
+<title>18</title>
+<polygon fill="none" stroke="black" points="869,-1476 798,-1476 798,-1440 869,-1440 869,-1476"/>
+<text text-anchor="middle" x="833.5" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 16&#45;&gt;18 -->
+<g id="edge8" class="edge">
+<title>16&#45;&gt;18</title>
+<path fill="none" stroke="black" d="M770.26,-1511.7C780.74,-1502.97 793.61,-1492.24 805,-1482.75"/>
+<polygon fill="black" stroke="black" points="807.53,-1485.19 812.97,-1476.1 803.05,-1479.82 807.53,-1485.19"/>
+</g>
+<!-- 17 -->
+<g id="node10" class="node">
+<title>17</title>
+<polygon fill="none" stroke="black" points="1133.5,-1548 941.5,-1548 941.5,-1512 1133.5,-1512 1133.5,-1548"/>
+<text text-anchor="middle" x="1037.5" y="-1526.3" font-family="Times,serif" font-size="14.00">Constant((768,), float32)</text>
+</g>
+<!-- 17&#45;&gt;18 -->
+<g id="edge9" class="edge">
+<title>17&#45;&gt;18</title>
+<path fill="none" stroke="black" d="M987.86,-1511.97C954.62,-1500.56 911.2,-1485.66 878.88,-1474.57"/>
+<polygon fill="black" stroke="black" points="879.83,-1471.2 869.23,-1471.26 877.55,-1477.82 879.83,-1471.2"/>
+</g>
+<!-- 19 -->
+<g id="node12" class="node">
+<title>19</title>
+<polygon fill="none" stroke="black" points="1016,-1404 651,-1404 651,-1368 1016,-1368 1016,-1404"/>
+<text text-anchor="middle" x="833.5" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 18&#45;&gt;19 -->
+<g id="edge10" class="edge">
+<title>18&#45;&gt;19</title>
+<path fill="none" stroke="black" d="M833.5,-1439.7C833.5,-1431.98 833.5,-1422.71 833.5,-1414.11"/>
+<polygon fill="black" stroke="black" points="837,-1414.1 833.5,-1404.1 830,-1414.1 837,-1414.1"/>
+</g>
+<!-- 20 -->
+<g id="node13" class="node">
+<title>20</title>
+<polygon fill="none" stroke="black" points="921,-1332 692,-1332 692,-1296 921,-1296 921,-1332"/>
+<text text-anchor="middle" x="806.5" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 19&#45;&gt;20 -->
+<g id="edge11" class="edge">
+<title>19&#45;&gt;20</title>
+<path fill="none" stroke="black" d="M826.83,-1367.7C823.79,-1359.81 820.12,-1350.3 816.74,-1341.55"/>
+<polygon fill="black" stroke="black" points="819.96,-1340.17 813.1,-1332.1 813.43,-1342.69 819.96,-1340.17"/>
+</g>
+<!-- 21 -->
+<g id="node14" class="node">
+<title>21</title>
+<polygon fill="none" stroke="black" points="973,-1116 630,-1116 630,-1080 973,-1080 973,-1116"/>
+<text text-anchor="middle" x="801.5" y="-1094.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 64], reverse=0)</text>
+</g>
+<!-- 20&#45;&gt;21 -->
+<g id="edge12" class="edge">
+<title>20&#45;&gt;21</title>
+<path fill="none" stroke="black" d="M806.1,-1295.85C805.23,-1258.83 803.19,-1171.18 802.14,-1126.39"/>
+<polygon fill="black" stroke="black" points="805.64,-1126.15 801.9,-1116.23 798.64,-1126.31 805.64,-1126.15"/>
+</g>
+<!-- 35 -->
+<g id="node28" class="node">
+<title>35</title>
+<polygon fill="none" stroke="black" points="880,-1044 711,-1044 711,-1008 880,-1008 880,-1044"/>
+<text text-anchor="middle" x="795.5" y="-1022.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 21&#45;&gt;35 -->
+<g id="edge26" class="edge">
+<title>21&#45;&gt;35</title>
+<path fill="none" stroke="black" d="M800.02,-1079.7C799.36,-1071.98 798.56,-1062.71 797.82,-1054.11"/>
+<polygon fill="black" stroke="black" points="801.31,-1053.77 796.97,-1044.1 794.33,-1054.37 801.31,-1053.77"/>
+</g>
+<!-- 22 -->
+<g id="node15" class="node">
+<title>22</title>
+<polygon fill="none" stroke="black" points="302.5,-1908 78.5,-1908 78.5,-1872 302.5,-1872 302.5,-1908"/>
+<text text-anchor="middle" x="190.5" y="-1886.3" font-family="Times,serif" font-size="14.00">Constant((768, 768), float32)</text>
+</g>
+<!-- 23 -->
+<g id="node16" class="node">
+<title>23</title>
+<polygon fill="none" stroke="black" points="287,-1836 94,-1836 94,-1800 287,-1800 287,-1836"/>
+<text text-anchor="middle" x="190.5" y="-1814.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 22&#45;&gt;23 -->
+<g id="edge13" class="edge">
+<title>22&#45;&gt;23</title>
+<path fill="none" stroke="black" d="M190.5,-1871.7C190.5,-1863.98 190.5,-1854.71 190.5,-1846.11"/>
+<polygon fill="black" stroke="black" points="194,-1846.1 190.5,-1836.1 187,-1846.1 194,-1846.1"/>
+</g>
+<!-- 24 -->
+<g id="node17" class="node">
+<title>24</title>
+<polygon fill="none" stroke="black" points="371,-1764 10,-1764 10,-1728 371,-1728 371,-1764"/>
+<text text-anchor="middle" x="190.5" y="-1742.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 23&#45;&gt;24 -->
+<g id="edge14" class="edge">
+<title>23&#45;&gt;24</title>
+<path fill="none" stroke="black" d="M190.5,-1799.7C190.5,-1791.98 190.5,-1782.71 190.5,-1774.11"/>
+<polygon fill="black" stroke="black" points="194,-1774.1 190.5,-1764.1 187,-1774.1 194,-1774.1"/>
+</g>
+<!-- 25 -->
+<g id="node18" class="node">
+<title>25</title>
+<polygon fill="none" stroke="black" points="296,-1692 85,-1692 85,-1656 296,-1656 296,-1692"/>
+<text text-anchor="middle" x="190.5" y="-1670.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 24&#45;&gt;25 -->
+<g id="edge15" class="edge">
+<title>24&#45;&gt;25</title>
+<path fill="none" stroke="black" d="M190.5,-1727.7C190.5,-1719.98 190.5,-1710.71 190.5,-1702.11"/>
+<polygon fill="black" stroke="black" points="194,-1702.1 190.5,-1692.1 187,-1702.1 194,-1702.1"/>
+</g>
+<!-- 25&#45;&gt;26 -->
+<g id="edge17" class="edge">
+<title>25&#45;&gt;26</title>
+<path fill="none" stroke="black" d="M190.5,-1655.7C190.5,-1647.98 190.5,-1638.71 190.5,-1630.11"/>
+<polygon fill="black" stroke="black" points="194,-1630.1 190.5,-1620.1 187,-1630.1 194,-1630.1"/>
+</g>
+<!-- 27 -->
+<g id="node20" class="node">
+<title>27</title>
+<polygon fill="none" stroke="black" points="347,-1548 0,-1548 0,-1512 347,-1512 347,-1548"/>
+<text text-anchor="middle" x="173.5" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 26&#45;&gt;27 -->
+<g id="edge18" class="edge">
+<title>26&#45;&gt;27</title>
+<path fill="none" stroke="black" d="M186.3,-1583.7C184.4,-1575.9 182.12,-1566.51 180.02,-1557.83"/>
+<polygon fill="black" stroke="black" points="183.42,-1557 177.65,-1548.1 176.61,-1558.65 183.42,-1557"/>
+</g>
+<!-- 29 -->
+<g id="node22" class="node">
+<title>29</title>
+<polygon fill="none" stroke="black" points="486,-1476 415,-1476 415,-1440 486,-1440 486,-1476"/>
+<text text-anchor="middle" x="450.5" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 27&#45;&gt;29 -->
+<g id="edge19" class="edge">
+<title>27&#45;&gt;29</title>
+<path fill="none" stroke="black" d="M240.91,-1511.97C291.72,-1499.13 360.04,-1481.86 404.66,-1470.58"/>
+<polygon fill="black" stroke="black" points="405.8,-1473.91 414.64,-1468.06 404.09,-1467.12 405.8,-1473.91"/>
+</g>
+<!-- 28 -->
+<g id="node21" class="node">
+<title>28</title>
+<polygon fill="none" stroke="black" points="557.5,-1548 365.5,-1548 365.5,-1512 557.5,-1512 557.5,-1548"/>
+<text text-anchor="middle" x="461.5" y="-1526.3" font-family="Times,serif" font-size="14.00">Constant((768,), float32)</text>
+</g>
+<!-- 28&#45;&gt;29 -->
+<g id="edge20" class="edge">
+<title>28&#45;&gt;29</title>
+<path fill="none" stroke="black" d="M458.78,-1511.7C457.57,-1503.98 456.11,-1494.71 454.76,-1486.11"/>
+<polygon fill="black" stroke="black" points="458.2,-1485.44 453.19,-1476.1 451.28,-1486.53 458.2,-1485.44"/>
+</g>
+<!-- 30 -->
+<g id="node23" class="node">
+<title>30</title>
+<polygon fill="none" stroke="black" points="633,-1404 268,-1404 268,-1368 633,-1368 633,-1404"/>
+<text text-anchor="middle" x="450.5" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 29&#45;&gt;30 -->
+<g id="edge21" class="edge">
+<title>29&#45;&gt;30</title>
+<path fill="none" stroke="black" d="M450.5,-1439.7C450.5,-1431.98 450.5,-1422.71 450.5,-1414.11"/>
+<polygon fill="black" stroke="black" points="454,-1414.1 450.5,-1404.1 447,-1414.1 454,-1414.1"/>
+</g>
+<!-- 31 -->
+<g id="node24" class="node">
+<title>31</title>
+<polygon fill="none" stroke="black" points="566,-1332 337,-1332 337,-1296 566,-1296 566,-1332"/>
+<text text-anchor="middle" x="451.5" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 30&#45;&gt;31 -->
+<g id="edge22" class="edge">
+<title>30&#45;&gt;31</title>
+<path fill="none" stroke="black" d="M450.75,-1367.7C450.86,-1359.98 450.99,-1350.71 451.11,-1342.11"/>
+<polygon fill="black" stroke="black" points="454.61,-1342.15 451.26,-1332.1 447.61,-1342.05 454.61,-1342.15"/>
+</g>
+<!-- 32 -->
+<g id="node25" class="node">
+<title>32</title>
+<polygon fill="none" stroke="black" points="596,-1260 367,-1260 367,-1224 596,-1224 596,-1260"/>
+<text text-anchor="middle" x="481.5" y="-1238.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 1, 3, 2])</text>
+</g>
+<!-- 31&#45;&gt;32 -->
+<g id="edge23" class="edge">
+<title>31&#45;&gt;32</title>
+<path fill="none" stroke="black" d="M458.92,-1295.7C462.29,-1287.81 466.37,-1278.3 470.12,-1269.55"/>
+<polygon fill="black" stroke="black" points="473.45,-1270.67 474.17,-1260.1 467.01,-1267.92 473.45,-1270.67"/>
+</g>
+<!-- 33 -->
+<g id="node26" class="node">
+<title>33</title>
+<polygon fill="none" stroke="black" points="668,-1188 325,-1188 325,-1152 668,-1152 668,-1188"/>
+<text text-anchor="middle" x="496.5" y="-1166.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 64, 14], reverse=0)</text>
+</g>
+<!-- 32&#45;&gt;33 -->
+<g id="edge24" class="edge">
+<title>32&#45;&gt;33</title>
+<path fill="none" stroke="black" d="M485.21,-1223.7C486.86,-1215.98 488.85,-1206.71 490.69,-1198.11"/>
+<polygon fill="black" stroke="black" points="494.16,-1198.62 492.83,-1188.1 487.32,-1197.15 494.16,-1198.62"/>
+</g>
+<!-- 34 -->
+<g id="node27" class="node">
+<title>34</title>
+<polygon fill="none" stroke="black" points="610,-1116 399,-1116 399,-1080 610,-1080 610,-1116"/>
+<text text-anchor="middle" x="504.5" y="-1094.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 33&#45;&gt;34 -->
+<g id="edge25" class="edge">
+<title>33&#45;&gt;34</title>
+<path fill="none" stroke="black" d="M498.48,-1151.7C499.36,-1143.98 500.42,-1134.71 501.4,-1126.11"/>
+<polygon fill="black" stroke="black" points="504.89,-1126.44 502.55,-1116.1 497.93,-1125.64 504.89,-1126.44"/>
+</g>
+<!-- 34&#45;&gt;35 -->
+<g id="edge27" class="edge">
+<title>34&#45;&gt;35</title>
+<path fill="none" stroke="black" d="M575.31,-1079.97C617.47,-1069.83 671.09,-1056.93 714.87,-1046.4"/>
+<polygon fill="black" stroke="black" points="715.72,-1049.79 724.62,-1044.05 714.08,-1042.99 715.72,-1049.79"/>
+</g>
+<!-- 36 -->
+<g id="node29" class="node">
+<title>36</title>
+<polygon fill="none" stroke="black" points="978,-972 613,-972 613,-936 978,-936 978,-972"/>
+<text text-anchor="middle" x="795.5" y="-950.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 12, 14, 14], reverse=0)</text>
+</g>
+<!-- 35&#45;&gt;36 -->
+<g id="edge28" class="edge">
+<title>35&#45;&gt;36</title>
+<path fill="none" stroke="black" d="M795.5,-1007.7C795.5,-999.98 795.5,-990.71 795.5,-982.11"/>
+<polygon fill="black" stroke="black" points="799,-982.1 795.5,-972.1 792,-982.1 799,-982.1"/>
+</g>
+<!-- 38 -->
+<g id="node30" class="node">
+<title>38</title>
+<polygon fill="none" stroke="black" points="849,-900 742,-900 742,-864 849,-864 849,-900"/>
+<text text-anchor="middle" x="795.5" y="-878.3" font-family="Times,serif" font-size="14.00">divide(·, 8.0)</text>
+</g>
+<!-- 36&#45;&gt;38 -->
+<g id="edge29" class="edge">
+<title>36&#45;&gt;38</title>
+<path fill="none" stroke="black" d="M795.5,-935.7C795.5,-927.98 795.5,-918.71 795.5,-910.11"/>
+<polygon fill="black" stroke="black" points="799,-910.1 795.5,-900.1 792,-910.1 799,-910.1"/>
+</g>
+<!-- 38&#45;&gt;39 -->
+<g id="edge30" class="edge">
+<title>38&#45;&gt;39</title>
+<path fill="none" stroke="black" d="M849.24,-866.13C852.03,-865.4 854.8,-864.69 857.5,-864 920.04,-848.13 992.73,-831.47 1038.67,-821.16"/>
+<polygon fill="black" stroke="black" points="1039.63,-824.53 1048.62,-818.93 1038.1,-817.7 1039.63,-824.53"/>
+</g>
+<!-- 40 -->
+<g id="node32" class="node">
+<title>40</title>
+<polygon fill="none" stroke="black" points="1194.5,-756 1020.5,-756 1020.5,-720 1194.5,-720 1194.5,-756"/>
+<text text-anchor="middle" x="1107.5" y="-734.3" font-family="Times,serif" font-size="14.00">nn.softmax(·| axis=&#45;1)</text>
+</g>
+<!-- 39&#45;&gt;40 -->
+<g id="edge32" class="edge">
+<title>39&#45;&gt;40</title>
+<path fill="none" stroke="black" d="M1090.19,-791.7C1092.75,-783.9 1095.83,-774.51 1098.68,-765.83"/>
+<polygon fill="black" stroke="black" points="1102.08,-766.7 1101.88,-756.1 1095.43,-764.51 1102.08,-766.7"/>
+</g>
+<!-- 41 -->
+<g id="node33" class="node">
+<title>41</title>
+<polygon fill="none" stroke="black" points="1211,-684 1028,-684 1028,-648 1211,-648 1211,-684"/>
+<text text-anchor="middle" x="1119.5" y="-662.3" font-family="Times,serif" font-size="14.00">nn.dropout(·| rate=0.1)</text>
+</g>
+<!-- 40&#45;&gt;41 -->
+<g id="edge33" class="edge">
+<title>40&#45;&gt;41</title>
+<path fill="none" stroke="black" d="M1110.47,-719.7C1111.79,-711.98 1113.38,-702.71 1114.85,-694.11"/>
+<polygon fill="black" stroke="black" points="1118.33,-694.55 1116.57,-684.1 1111.43,-693.37 1118.33,-694.55"/>
+</g>
+<!-- 42 -->
+<g id="node34" class="node">
+<title>42</title>
+<polygon fill="none" stroke="black" points="1209.5,-612 1041.5,-612 1041.5,-576 1209.5,-576 1209.5,-612"/>
+<text text-anchor="middle" x="1125.5" y="-590.3" font-family="Times,serif" font-size="14.00">TupleGetItem(idx=0)</text>
+</g>
+<!-- 41&#45;&gt;42 -->
+<g id="edge34" class="edge">
+<title>41&#45;&gt;42</title>
+<path fill="none" stroke="black" d="M1120.98,-647.7C1121.64,-639.98 1122.44,-630.71 1123.18,-622.11"/>
+<polygon fill="black" stroke="black" points="1126.67,-622.37 1124.03,-612.1 1119.69,-621.77 1126.67,-622.37"/>
+</g>
+<!-- 43 -->
+<g id="node35" class="node">
+<title>43</title>
+<polygon fill="none" stroke="black" points="1300,-540 957,-540 957,-504 1300,-504 1300,-540"/>
+<text text-anchor="middle" x="1128.5" y="-518.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 14], reverse=0)</text>
+</g>
+<!-- 42&#45;&gt;43 -->
+<g id="edge35" class="edge">
+<title>42&#45;&gt;43</title>
+<path fill="none" stroke="black" d="M1126.24,-575.7C1126.57,-567.98 1126.97,-558.71 1127.34,-550.11"/>
+<polygon fill="black" stroke="black" points="1130.84,-550.25 1127.77,-540.1 1123.84,-549.95 1130.84,-550.25"/>
+</g>
+<!-- 56 -->
+<g id="node48" class="node">
+<title>56</title>
+<polygon fill="none" stroke="black" points="1312,-468 1143,-468 1143,-432 1312,-432 1312,-468"/>
+<text text-anchor="middle" x="1227.5" y="-446.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 43&#45;&gt;56 -->
+<g id="edge48" class="edge">
+<title>43&#45;&gt;56</title>
+<path fill="none" stroke="black" d="M1152.72,-503.88C1165.42,-494.89 1181.16,-483.76 1194.92,-474.03"/>
+<polygon fill="black" stroke="black" points="1197.04,-476.82 1203.19,-468.19 1193,-471.11 1197.04,-476.82"/>
+</g>
+<!-- 44 -->
+<g id="node36" class="node">
+<title>44</title>
+<polygon fill="none" stroke="black" points="1432.5,-1908 1208.5,-1908 1208.5,-1872 1432.5,-1872 1432.5,-1908"/>
+<text text-anchor="middle" x="1320.5" y="-1886.3" font-family="Times,serif" font-size="14.00">Constant((768, 768), float32)</text>
+</g>
+<!-- 45 -->
+<g id="node37" class="node">
+<title>45</title>
+<polygon fill="none" stroke="black" points="1417,-1836 1224,-1836 1224,-1800 1417,-1800 1417,-1836"/>
+<text text-anchor="middle" x="1320.5" y="-1814.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 44&#45;&gt;45 -->
+<g id="edge36" class="edge">
+<title>44&#45;&gt;45</title>
+<path fill="none" stroke="black" d="M1320.5,-1871.7C1320.5,-1863.98 1320.5,-1854.71 1320.5,-1846.11"/>
+<polygon fill="black" stroke="black" points="1324,-1846.1 1320.5,-1836.1 1317,-1846.1 1324,-1846.1"/>
+</g>
+<!-- 46 -->
+<g id="node38" class="node">
+<title>46</title>
+<polygon fill="none" stroke="black" points="1501,-1764 1140,-1764 1140,-1728 1501,-1728 1501,-1764"/>
+<text text-anchor="middle" x="1320.5" y="-1742.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 45&#45;&gt;46 -->
+<g id="edge37" class="edge">
+<title>45&#45;&gt;46</title>
+<path fill="none" stroke="black" d="M1320.5,-1799.7C1320.5,-1791.98 1320.5,-1782.71 1320.5,-1774.11"/>
+<polygon fill="black" stroke="black" points="1324,-1774.1 1320.5,-1764.1 1317,-1774.1 1324,-1774.1"/>
+</g>
+<!-- 47 -->
+<g id="node39" class="node">
+<title>47</title>
+<polygon fill="none" stroke="black" points="1426,-1692 1215,-1692 1215,-1656 1426,-1656 1426,-1692"/>
+<text text-anchor="middle" x="1320.5" y="-1670.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 46&#45;&gt;47 -->
+<g id="edge38" class="edge">
+<title>46&#45;&gt;47</title>
+<path fill="none" stroke="black" d="M1320.5,-1727.7C1320.5,-1719.98 1320.5,-1710.71 1320.5,-1702.11"/>
+<polygon fill="black" stroke="black" points="1324,-1702.1 1320.5,-1692.1 1317,-1702.1 1324,-1702.1"/>
+</g>
+<!-- 47&#45;&gt;48 -->
+<g id="edge40" class="edge">
+<title>47&#45;&gt;48</title>
+<path fill="none" stroke="black" d="M1320.5,-1655.7C1320.5,-1647.98 1320.5,-1638.71 1320.5,-1630.11"/>
+<polygon fill="black" stroke="black" points="1324,-1630.1 1320.5,-1620.1 1317,-1630.1 1324,-1630.1"/>
+</g>
+<!-- 49 -->
+<g id="node41" class="node">
+<title>49</title>
+<polygon fill="none" stroke="black" points="1499,-1548 1152,-1548 1152,-1512 1499,-1512 1499,-1548"/>
+<text text-anchor="middle" x="1325.5" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 48&#45;&gt;49 -->
+<g id="edge41" class="edge">
+<title>48&#45;&gt;49</title>
+<path fill="none" stroke="black" d="M1321.74,-1583.7C1322.29,-1575.98 1322.95,-1566.71 1323.56,-1558.11"/>
+<polygon fill="black" stroke="black" points="1327.06,-1558.33 1324.28,-1548.1 1320.07,-1557.83 1327.06,-1558.33"/>
+</g>
+<!-- 51 -->
+<g id="node43" class="node">
+<title>51</title>
+<polygon fill="none" stroke="black" points="1366,-1476 1295,-1476 1295,-1440 1366,-1440 1366,-1476"/>
+<text text-anchor="middle" x="1330.5" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 49&#45;&gt;51 -->
+<g id="edge42" class="edge">
+<title>49&#45;&gt;51</title>
+<path fill="none" stroke="black" d="M1326.74,-1511.7C1327.29,-1503.98 1327.95,-1494.71 1328.56,-1486.11"/>
+<polygon fill="black" stroke="black" points="1332.06,-1486.33 1329.28,-1476.1 1325.07,-1485.83 1332.06,-1486.33"/>
+</g>
+<!-- 50 -->
+<g id="node42" class="node">
+<title>50</title>
+<polygon fill="none" stroke="black" points="1709.5,-1548 1517.5,-1548 1517.5,-1512 1709.5,-1512 1709.5,-1548"/>
+<text text-anchor="middle" x="1613.5" y="-1526.3" font-family="Times,serif" font-size="14.00">Constant((768,), float32)</text>
+</g>
+<!-- 50&#45;&gt;51 -->
+<g id="edge43" class="edge">
+<title>50&#45;&gt;51</title>
+<path fill="none" stroke="black" d="M1544.63,-1511.97C1492.31,-1499.02 1421.8,-1481.58 1376.25,-1470.32"/>
+<polygon fill="black" stroke="black" points="1376.9,-1466.87 1366.36,-1467.87 1375.22,-1473.67 1376.9,-1466.87"/>
+</g>
+<!-- 52 -->
+<g id="node44" class="node">
+<title>52</title>
+<polygon fill="none" stroke="black" points="1513,-1404 1148,-1404 1148,-1368 1513,-1368 1513,-1404"/>
+<text text-anchor="middle" x="1330.5" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 51&#45;&gt;52 -->
+<g id="edge44" class="edge">
+<title>51&#45;&gt;52</title>
+<path fill="none" stroke="black" d="M1330.5,-1439.7C1330.5,-1431.98 1330.5,-1422.71 1330.5,-1414.11"/>
+<polygon fill="black" stroke="black" points="1334,-1414.1 1330.5,-1404.1 1327,-1414.1 1334,-1414.1"/>
+</g>
+<!-- 53 -->
+<g id="node45" class="node">
+<title>53</title>
+<polygon fill="none" stroke="black" points="1445,-1332 1216,-1332 1216,-1296 1445,-1296 1445,-1332"/>
+<text text-anchor="middle" x="1330.5" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 52&#45;&gt;53 -->
+<g id="edge45" class="edge">
+<title>52&#45;&gt;53</title>
+<path fill="none" stroke="black" d="M1330.5,-1367.7C1330.5,-1359.98 1330.5,-1350.71 1330.5,-1342.11"/>
+<polygon fill="black" stroke="black" points="1334,-1342.1 1330.5,-1332.1 1327,-1342.1 1334,-1342.1"/>
+</g>
+<!-- 54 -->
+<g id="node46" class="node">
+<title>54</title>
+<polygon fill="none" stroke="black" points="1502,-1260 1159,-1260 1159,-1224 1502,-1224 1502,-1260"/>
+<text text-anchor="middle" x="1330.5" y="-1238.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 64], reverse=0)</text>
+</g>
+<!-- 53&#45;&gt;54 -->
+<g id="edge46" class="edge">
+<title>53&#45;&gt;54</title>
+<path fill="none" stroke="black" d="M1330.5,-1295.7C1330.5,-1287.98 1330.5,-1278.71 1330.5,-1270.11"/>
+<polygon fill="black" stroke="black" points="1334,-1270.1 1330.5,-1260.1 1327,-1270.1 1334,-1270.1"/>
+</g>
+<!-- 55 -->
+<g id="node47" class="node">
+<title>55</title>
+<polygon fill="none" stroke="black" points="1436,-1044 1225,-1044 1225,-1008 1436,-1008 1436,-1044"/>
+<text text-anchor="middle" x="1330.5" y="-1022.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 54&#45;&gt;55 -->
+<g id="edge47" class="edge">
+<title>54&#45;&gt;55</title>
+<path fill="none" stroke="black" d="M1330.5,-1223.85C1330.5,-1186.83 1330.5,-1099.18 1330.5,-1054.39"/>
+<polygon fill="black" stroke="black" points="1334,-1054.23 1330.5,-1044.23 1327,-1054.23 1334,-1054.23"/>
+</g>
+<!-- 55&#45;&gt;56 -->
+<g id="edge49" class="edge">
+<title>55&#45;&gt;56</title>
+<path fill="none" stroke="black" d="M1330.5,-1007.95C1330.5,-981.29 1330.5,-928.11 1330.5,-883 1330.5,-883 1330.5,-883 1330.5,-593 1330.5,-552.36 1332.93,-537.21 1309.5,-504 1300.77,-491.62 1288.14,-481.42 1275.4,-473.4"/>
+<polygon fill="black" stroke="black" points="1276.99,-470.27 1266.6,-468.19 1273.42,-476.3 1276.99,-470.27"/>
+</g>
+<!-- 57 -->
+<g id="node49" class="node">
+<title>57</title>
+<polygon fill="none" stroke="black" points="1410,-396 1045,-396 1045,-360 1410,-360 1410,-396"/>
+<text text-anchor="middle" x="1227.5" y="-374.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 12, 14, 64], reverse=0)</text>
+</g>
+<!-- 56&#45;&gt;57 -->
+<g id="edge50" class="edge">
+<title>56&#45;&gt;57</title>
+<path fill="none" stroke="black" d="M1227.5,-431.7C1227.5,-423.98 1227.5,-414.71 1227.5,-406.11"/>
+<polygon fill="black" stroke="black" points="1231,-406.1 1227.5,-396.1 1224,-406.1 1231,-406.1"/>
+</g>
+<!-- 58 -->
+<g id="node50" class="node">
+<title>58</title>
+<polygon fill="none" stroke="black" points="1342,-324 1113,-324 1113,-288 1342,-288 1342,-324"/>
+<text text-anchor="middle" x="1227.5" y="-302.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 57&#45;&gt;58 -->
+<g id="edge51" class="edge">
+<title>57&#45;&gt;58</title>
+<path fill="none" stroke="black" d="M1227.5,-359.7C1227.5,-351.98 1227.5,-342.71 1227.5,-334.11"/>
+<polygon fill="black" stroke="black" points="1231,-334.1 1227.5,-324.1 1224,-334.1 1231,-334.1"/>
+</g>
+<!-- 59 -->
+<g id="node51" class="node">
+<title>59</title>
+<polygon fill="none" stroke="black" points="1260,-252 1195,-252 1195,-216 1260,-216 1260,-252"/>
+<text text-anchor="middle" x="1227.5" y="-230.3" font-family="Times,serif" font-size="14.00">copy(·)</text>
+</g>
+<!-- 58&#45;&gt;59 -->
+<g id="edge52" class="edge">
+<title>58&#45;&gt;59</title>
+<path fill="none" stroke="black" d="M1227.5,-287.7C1227.5,-279.98 1227.5,-270.71 1227.5,-262.11"/>
+<polygon fill="black" stroke="black" points="1231,-262.1 1227.5,-252.1 1224,-262.1 1231,-262.1"/>
+</g>
+<!-- 60 -->
+<g id="node52" class="node">
+<title>60</title>
+<polygon fill="none" stroke="black" points="1401,-180 1054,-180 1054,-144 1401,-144 1401,-180"/>
+<text text-anchor="middle" x="1227.5" y="-158.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 59&#45;&gt;60 -->
+<g id="edge53" class="edge">
+<title>59&#45;&gt;60</title>
+<path fill="none" stroke="black" d="M1227.5,-215.7C1227.5,-207.98 1227.5,-198.71 1227.5,-190.11"/>
+<polygon fill="black" stroke="black" points="1231,-190.1 1227.5,-180.1 1224,-190.1 1231,-190.1"/>
+</g>
+<!-- 61 -->
+<g id="node53" class="node">
+<title>61</title>
+<polygon fill="none" stroke="black" points="1270.5,-108 1184.5,-108 1184.5,-72 1270.5,-72 1270.5,-108"/>
+<text text-anchor="middle" x="1227.5" y="-86.3" font-family="Times,serif" font-size="14.00">Tuple[...])</text>
+</g>
+<!-- 60&#45;&gt;61 -->
+<g id="edge54" class="edge">
+<title>60&#45;&gt;61</title>
+<path fill="none" stroke="black" d="M1227.5,-143.7C1227.5,-135.98 1227.5,-126.71 1227.5,-118.11"/>
+<polygon fill="black" stroke="black" points="1231,-118.1 1227.5,-108.1 1224,-118.1 1231,-118.1"/>
+</g>
+<!-- 62 -->
+<g id="node54" class="node">
+<title>62</title>
+<polygon fill="none" stroke="black" points="1267.5,-36 1187.5,-36 1187.5,0 1267.5,0 1267.5,-36"/>
+<text text-anchor="middle" x="1227.5" y="-14.3" font-family="Times,serif" font-size="14.00">Function</text>
+</g>
+<!-- 61&#45;&gt;62 -->
+<g id="edge55" class="edge">
+<title>61&#45;&gt;62</title>
+<path fill="none" stroke="black" d="M1227.5,-71.7C1227.5,-63.98 1227.5,-54.71 1227.5,-46.11"/>
+<polygon fill="black" stroke="black" points="1231,-46.1 1227.5,-36.1 1224,-46.1 1231,-46.1"/>
+</g>
+</g>
+</svg>
diff --git a/images/bert-pytorch/bert-tvm_72_0.svg b/images/bert-pytorch/bert-tvm_72_0.svg
new file mode 100644
index 0000000..f8b6dca
--- /dev/null
+++ b/images/bert-pytorch/bert-tvm_72_0.svg
@@ -0,0 +1,559 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
+ "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
+<!-- Generated by graphviz version 2.43.0 (0)
+ -->
+<!-- Title: %3 Pages: 1 -->
+<svg width="1718pt" height="1772pt"
+ viewBox="0.00 0.00 1717.50 1772.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
+<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 1768)">
+<title>%3</title>
+<polygon fill="white" stroke="transparent" points="-4,4 -4,-1768 1713.5,-1768 1713.5,4 -4,4"/>
+<!-- 0 -->
+<g id="node1" class="node">
+<title>0</title>
+<ellipse fill="none" stroke="black" cx="814.5" cy="-1746" rx="170.87" ry="18"/>
+<text text-anchor="middle" x="814.5" y="-1742.3" font-family="Times,serif" font-size="14.00">input: Tensor[(1, 14, 768), float32]</text>
+</g>
+<!-- 10 -->
+<g id="node3" class="node">
+<title>10</title>
+<polygon fill="none" stroke="black" points="990.5,-1692 638.5,-1692 638.5,-1656 990.5,-1656 990.5,-1692"/>
+<text text-anchor="middle" x="814.5" y="-1670.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 0&#45;&gt;10 -->
+<g id="edge1" class="edge">
+<title>0&#45;&gt;10</title>
+<path fill="none" stroke="black" d="M814.5,-1727.7C814.5,-1719.98 814.5,-1710.71 814.5,-1702.11"/>
+<polygon fill="black" stroke="black" points="818,-1702.1 814.5,-1692.1 811,-1702.1 818,-1702.1"/>
+</g>
+<!-- 1 -->
+<g id="node2" class="node">
+<title>1</title>
+<ellipse fill="none" stroke="black" cx="1084.5" cy="-882" rx="217.96" ry="18"/>
+<text text-anchor="middle" x="1084.5" y="-878.3" font-family="Times,serif" font-size="14.00">attention_mask: Tensor[(1, 1, 1, 14), float32]</text>
+</g>
+<!-- 33 -->
+<g id="node25" class="node">
+<title>33</title>
+<polygon fill="none" stroke="black" points="1120,-828 1049,-828 1049,-792 1120,-792 1120,-828"/>
+<text text-anchor="middle" x="1084.5" y="-806.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 1&#45;&gt;33 -->
+<g id="edge25" class="edge">
+<title>1&#45;&gt;33</title>
+<path fill="none" stroke="black" d="M1084.5,-863.7C1084.5,-855.98 1084.5,-846.71 1084.5,-838.11"/>
+<polygon fill="black" stroke="black" points="1088,-838.1 1084.5,-828.1 1081,-838.1 1088,-838.1"/>
+</g>
+<!-- 12 -->
+<g id="node5" class="node">
+<title>12</title>
+<polygon fill="none" stroke="black" points="834,-1620 665,-1620 665,-1584 834,-1584 834,-1620"/>
+<text text-anchor="middle" x="749.5" y="-1598.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 10&#45;&gt;12 -->
+<g id="edge2" class="edge">
+<title>10&#45;&gt;12</title>
+<path fill="none" stroke="black" d="M798.43,-1655.7C790.56,-1647.22 780.94,-1636.86 772.33,-1627.58"/>
+<polygon fill="black" stroke="black" points="774.75,-1625.05 765.38,-1620.1 769.62,-1629.81 774.75,-1625.05"/>
+</g>
+<!-- 20 -->
+<g id="node13" class="node">
+<title>20</title>
+<polygon fill="none" stroke="black" points="324,-1620 155,-1620 155,-1584 324,-1584 324,-1620"/>
+<text text-anchor="middle" x="239.5" y="-1598.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 10&#45;&gt;20 -->
+<g id="edge10" class="edge">
+<title>10&#45;&gt;20</title>
+<path fill="none" stroke="black" d="M674.57,-1655.97C569,-1643.11 426.99,-1625.83 334.38,-1614.55"/>
+<polygon fill="black" stroke="black" points="334.53,-1611.04 324.18,-1613.31 333.68,-1617.99 334.53,-1611.04"/>
+</g>
+<!-- 39 -->
+<g id="node31" class="node">
+<title>39</title>
+<polygon fill="none" stroke="black" points="1344,-1620 1175,-1620 1175,-1584 1344,-1584 1344,-1620"/>
+<text text-anchor="middle" x="1259.5" y="-1598.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 10&#45;&gt;39 -->
+<g id="edge30" class="edge">
+<title>10&#45;&gt;39</title>
+<path fill="none" stroke="black" d="M922.79,-1655.97C996.72,-1644.34 1093.73,-1629.08 1164.62,-1617.92"/>
+<polygon fill="black" stroke="black" points="1165.55,-1621.32 1174.89,-1616.31 1164.47,-1614.41 1165.55,-1621.32"/>
+</g>
+<!-- 11 -->
+<g id="node4" class="node">
+<title>11</title>
+<polygon fill="none" stroke="black" points="620.5,-1692 378.5,-1692 378.5,-1656 620.5,-1656 620.5,-1692"/>
+<text text-anchor="middle" x="499.5" y="-1670.3" font-family="Times,serif" font-size="14.00">Constant((1, 768, 768), float32)</text>
+</g>
+<!-- 11&#45;&gt;12 -->
+<g id="edge3" class="edge">
+<title>11&#45;&gt;12</title>
+<path fill="none" stroke="black" d="M560.34,-1655.97C596.03,-1645.97 641.29,-1633.3 678.59,-1622.85"/>
+<polygon fill="black" stroke="black" points="679.7,-1626.18 688.39,-1620.11 677.82,-1619.44 679.7,-1626.18"/>
+</g>
+<!-- 13 -->
+<g id="node6" class="node">
+<title>13</title>
+<polygon fill="none" stroke="black" points="923,-1548 576,-1548 576,-1512 923,-1512 923,-1548"/>
+<text text-anchor="middle" x="749.5" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 12&#45;&gt;13 -->
+<g id="edge4" class="edge">
+<title>12&#45;&gt;13</title>
+<path fill="none" stroke="black" d="M749.5,-1583.7C749.5,-1575.98 749.5,-1566.71 749.5,-1558.11"/>
+<polygon fill="black" stroke="black" points="753,-1558.1 749.5,-1548.1 746,-1558.1 753,-1558.1"/>
+</g>
+<!-- 15 -->
+<g id="node8" class="node">
+<title>15</title>
+<polygon fill="none" stroke="black" points="869,-1476 798,-1476 798,-1440 869,-1440 869,-1476"/>
+<text text-anchor="middle" x="833.5" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 13&#45;&gt;15 -->
+<g id="edge5" class="edge">
+<title>13&#45;&gt;15</title>
+<path fill="none" stroke="black" d="M770.26,-1511.7C780.74,-1502.97 793.61,-1492.24 805,-1482.75"/>
+<polygon fill="black" stroke="black" points="807.53,-1485.19 812.97,-1476.1 803.05,-1479.82 807.53,-1485.19"/>
+</g>
+<!-- 14 -->
+<g id="node7" class="node">
+<title>14</title>
+<polygon fill="none" stroke="black" points="1133.5,-1548 941.5,-1548 941.5,-1512 1133.5,-1512 1133.5,-1548"/>
+<text text-anchor="middle" x="1037.5" y="-1526.3" font-family="Times,serif" font-size="14.00">Constant((768,), float32)</text>
+</g>
+<!-- 14&#45;&gt;15 -->
+<g id="edge6" class="edge">
+<title>14&#45;&gt;15</title>
+<path fill="none" stroke="black" d="M987.86,-1511.97C954.62,-1500.56 911.2,-1485.66 878.88,-1474.57"/>
+<polygon fill="black" stroke="black" points="879.83,-1471.2 869.23,-1471.26 877.55,-1477.82 879.83,-1471.2"/>
+</g>
+<!-- 16 -->
+<g id="node9" class="node">
+<title>16</title>
+<polygon fill="none" stroke="black" points="1016,-1404 651,-1404 651,-1368 1016,-1368 1016,-1404"/>
+<text text-anchor="middle" x="833.5" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 15&#45;&gt;16 -->
+<g id="edge7" class="edge">
+<title>15&#45;&gt;16</title>
+<path fill="none" stroke="black" d="M833.5,-1439.7C833.5,-1431.98 833.5,-1422.71 833.5,-1414.11"/>
+<polygon fill="black" stroke="black" points="837,-1414.1 833.5,-1404.1 830,-1414.1 837,-1414.1"/>
+</g>
+<!-- 17 -->
+<g id="node10" class="node">
+<title>17</title>
+<polygon fill="none" stroke="black" points="921,-1332 692,-1332 692,-1296 921,-1296 921,-1332"/>
+<text text-anchor="middle" x="806.5" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 16&#45;&gt;17 -->
+<g id="edge8" class="edge">
+<title>16&#45;&gt;17</title>
+<path fill="none" stroke="black" d="M826.83,-1367.7C823.79,-1359.81 820.12,-1350.3 816.74,-1341.55"/>
+<polygon fill="black" stroke="black" points="819.96,-1340.17 813.1,-1332.1 813.43,-1342.69 819.96,-1340.17"/>
+</g>
+<!-- 18 -->
+<g id="node11" class="node">
+<title>18</title>
+<polygon fill="none" stroke="black" points="973,-1116 630,-1116 630,-1080 973,-1080 973,-1116"/>
+<text text-anchor="middle" x="801.5" y="-1094.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 64], reverse=0)</text>
+</g>
+<!-- 17&#45;&gt;18 -->
+<g id="edge9" class="edge">
+<title>17&#45;&gt;18</title>
+<path fill="none" stroke="black" d="M806.1,-1295.85C805.23,-1258.83 803.19,-1171.18 802.14,-1126.39"/>
+<polygon fill="black" stroke="black" points="805.64,-1126.15 801.9,-1116.23 798.64,-1126.31 805.64,-1126.15"/>
+</g>
+<!-- 29 -->
+<g id="node22" class="node">
+<title>29</title>
+<polygon fill="none" stroke="black" points="880,-1044 711,-1044 711,-1008 880,-1008 880,-1044"/>
+<text text-anchor="middle" x="795.5" y="-1022.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 18&#45;&gt;29 -->
+<g id="edge20" class="edge">
+<title>18&#45;&gt;29</title>
+<path fill="none" stroke="black" d="M800.02,-1079.7C799.36,-1071.98 798.56,-1062.71 797.82,-1054.11"/>
+<polygon fill="black" stroke="black" points="801.31,-1053.77 796.97,-1044.1 794.33,-1054.37 801.31,-1053.77"/>
+</g>
+<!-- 19 -->
+<g id="node12" class="node">
+<title>19</title>
+<polygon fill="none" stroke="black" points="360.5,-1692 118.5,-1692 118.5,-1656 360.5,-1656 360.5,-1692"/>
+<text text-anchor="middle" x="239.5" y="-1670.3" font-family="Times,serif" font-size="14.00">Constant((1, 768, 768), float32)</text>
+</g>
+<!-- 19&#45;&gt;20 -->
+<g id="edge11" class="edge">
+<title>19&#45;&gt;20</title>
+<path fill="none" stroke="black" d="M239.5,-1655.7C239.5,-1647.98 239.5,-1638.71 239.5,-1630.11"/>
+<polygon fill="black" stroke="black" points="243,-1630.1 239.5,-1620.1 236,-1630.1 243,-1630.1"/>
+</g>
+<!-- 21 -->
+<g id="node14" class="node">
+<title>21</title>
+<polygon fill="none" stroke="black" points="347,-1548 0,-1548 0,-1512 347,-1512 347,-1548"/>
+<text text-anchor="middle" x="173.5" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 20&#45;&gt;21 -->
+<g id="edge12" class="edge">
+<title>20&#45;&gt;21</title>
+<path fill="none" stroke="black" d="M223.19,-1583.7C215.19,-1575.22 205.43,-1564.86 196.68,-1555.58"/>
+<polygon fill="black" stroke="black" points="199.03,-1552.98 189.63,-1548.1 193.94,-1557.78 199.03,-1552.98"/>
+</g>
+<!-- 23 -->
+<g id="node16" class="node">
+<title>23</title>
+<polygon fill="none" stroke="black" points="486,-1476 415,-1476 415,-1440 486,-1440 486,-1476"/>
+<text text-anchor="middle" x="450.5" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 21&#45;&gt;23 -->
+<g id="edge13" class="edge">
+<title>21&#45;&gt;23</title>
+<path fill="none" stroke="black" d="M240.91,-1511.97C291.72,-1499.13 360.04,-1481.86 404.66,-1470.58"/>
+<polygon fill="black" stroke="black" points="405.8,-1473.91 414.64,-1468.06 404.09,-1467.12 405.8,-1473.91"/>
+</g>
+<!-- 22 -->
+<g id="node15" class="node">
+<title>22</title>
+<polygon fill="none" stroke="black" points="557.5,-1548 365.5,-1548 365.5,-1512 557.5,-1512 557.5,-1548"/>
+<text text-anchor="middle" x="461.5" y="-1526.3" font-family="Times,serif" font-size="14.00">Constant((768,), float32)</text>
+</g>
+<!-- 22&#45;&gt;23 -->
+<g id="edge14" class="edge">
+<title>22&#45;&gt;23</title>
+<path fill="none" stroke="black" d="M458.78,-1511.7C457.57,-1503.98 456.11,-1494.71 454.76,-1486.11"/>
+<polygon fill="black" stroke="black" points="458.2,-1485.44 453.19,-1476.1 451.28,-1486.53 458.2,-1485.44"/>
+</g>
+<!-- 24 -->
+<g id="node17" class="node">
+<title>24</title>
+<polygon fill="none" stroke="black" points="633,-1404 268,-1404 268,-1368 633,-1368 633,-1404"/>
+<text text-anchor="middle" x="450.5" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 23&#45;&gt;24 -->
+<g id="edge15" class="edge">
+<title>23&#45;&gt;24</title>
+<path fill="none" stroke="black" d="M450.5,-1439.7C450.5,-1431.98 450.5,-1422.71 450.5,-1414.11"/>
+<polygon fill="black" stroke="black" points="454,-1414.1 450.5,-1404.1 447,-1414.1 454,-1414.1"/>
+</g>
+<!-- 25 -->
+<g id="node18" class="node">
+<title>25</title>
+<polygon fill="none" stroke="black" points="566,-1332 337,-1332 337,-1296 566,-1296 566,-1332"/>
+<text text-anchor="middle" x="451.5" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 24&#45;&gt;25 -->
+<g id="edge16" class="edge">
+<title>24&#45;&gt;25</title>
+<path fill="none" stroke="black" d="M450.75,-1367.7C450.86,-1359.98 450.99,-1350.71 451.11,-1342.11"/>
+<polygon fill="black" stroke="black" points="454.61,-1342.15 451.26,-1332.1 447.61,-1342.05 454.61,-1342.15"/>
+</g>
+<!-- 26 -->
+<g id="node19" class="node">
+<title>26</title>
+<polygon fill="none" stroke="black" points="596,-1260 367,-1260 367,-1224 596,-1224 596,-1260"/>
+<text text-anchor="middle" x="481.5" y="-1238.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 1, 3, 2])</text>
+</g>
+<!-- 25&#45;&gt;26 -->
+<g id="edge17" class="edge">
+<title>25&#45;&gt;26</title>
+<path fill="none" stroke="black" d="M458.92,-1295.7C462.29,-1287.81 466.37,-1278.3 470.12,-1269.55"/>
+<polygon fill="black" stroke="black" points="473.45,-1270.67 474.17,-1260.1 467.01,-1267.92 473.45,-1270.67"/>
+</g>
+<!-- 27 -->
+<g id="node20" class="node">
+<title>27</title>
+<polygon fill="none" stroke="black" points="668,-1188 325,-1188 325,-1152 668,-1152 668,-1188"/>
+<text text-anchor="middle" x="496.5" y="-1166.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 64, 14], reverse=0)</text>
+</g>
+<!-- 26&#45;&gt;27 -->
+<g id="edge18" class="edge">
+<title>26&#45;&gt;27</title>
+<path fill="none" stroke="black" d="M485.21,-1223.7C486.86,-1215.98 488.85,-1206.71 490.69,-1198.11"/>
+<polygon fill="black" stroke="black" points="494.16,-1198.62 492.83,-1188.1 487.32,-1197.15 494.16,-1198.62"/>
+</g>
+<!-- 28 -->
+<g id="node21" class="node">
+<title>28</title>
+<polygon fill="none" stroke="black" points="610,-1116 399,-1116 399,-1080 610,-1080 610,-1116"/>
+<text text-anchor="middle" x="504.5" y="-1094.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 27&#45;&gt;28 -->
+<g id="edge19" class="edge">
+<title>27&#45;&gt;28</title>
+<path fill="none" stroke="black" d="M498.48,-1151.7C499.36,-1143.98 500.42,-1134.71 501.4,-1126.11"/>
+<polygon fill="black" stroke="black" points="504.89,-1126.44 502.55,-1116.1 497.93,-1125.64 504.89,-1126.44"/>
+</g>
+<!-- 28&#45;&gt;29 -->
+<g id="edge21" class="edge">
+<title>28&#45;&gt;29</title>
+<path fill="none" stroke="black" d="M575.31,-1079.97C617.47,-1069.83 671.09,-1056.93 714.87,-1046.4"/>
+<polygon fill="black" stroke="black" points="715.72,-1049.79 724.62,-1044.05 714.08,-1042.99 715.72,-1049.79"/>
+</g>
+<!-- 30 -->
+<g id="node23" class="node">
+<title>30</title>
+<polygon fill="none" stroke="black" points="978,-972 613,-972 613,-936 978,-936 978,-972"/>
+<text text-anchor="middle" x="795.5" y="-950.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 12, 14, 14], reverse=0)</text>
+</g>
+<!-- 29&#45;&gt;30 -->
+<g id="edge22" class="edge">
+<title>29&#45;&gt;30</title>
+<path fill="none" stroke="black" d="M795.5,-1007.7C795.5,-999.98 795.5,-990.71 795.5,-982.11"/>
+<polygon fill="black" stroke="black" points="799,-982.1 795.5,-972.1 792,-982.1 799,-982.1"/>
+</g>
+<!-- 32 -->
+<g id="node24" class="node">
+<title>32</title>
+<polygon fill="none" stroke="black" points="849,-900 742,-900 742,-864 849,-864 849,-900"/>
+<text text-anchor="middle" x="795.5" y="-878.3" font-family="Times,serif" font-size="14.00">divide(·, 8.0)</text>
+</g>
+<!-- 30&#45;&gt;32 -->
+<g id="edge23" class="edge">
+<title>30&#45;&gt;32</title>
+<path fill="none" stroke="black" d="M795.5,-935.7C795.5,-927.98 795.5,-918.71 795.5,-910.11"/>
+<polygon fill="black" stroke="black" points="799,-910.1 795.5,-900.1 792,-910.1 799,-910.1"/>
+</g>
+<!-- 32&#45;&gt;33 -->
+<g id="edge24" class="edge">
+<title>32&#45;&gt;33</title>
+<path fill="none" stroke="black" d="M849.24,-866.13C852.03,-865.4 854.8,-864.69 857.5,-864 920.04,-848.13 992.73,-831.47 1038.67,-821.16"/>
+<polygon fill="black" stroke="black" points="1039.63,-824.53 1048.62,-818.93 1038.1,-817.7 1039.63,-824.53"/>
+</g>
+<!-- 34 -->
+<g id="node26" class="node">
+<title>34</title>
+<polygon fill="none" stroke="black" points="1174.5,-756 1000.5,-756 1000.5,-720 1174.5,-720 1174.5,-756"/>
+<text text-anchor="middle" x="1087.5" y="-734.3" font-family="Times,serif" font-size="14.00">nn.softmax(·| axis=&#45;1)</text>
+</g>
+<!-- 33&#45;&gt;34 -->
+<g id="edge26" class="edge">
+<title>33&#45;&gt;34</title>
+<path fill="none" stroke="black" d="M1085.24,-791.7C1085.57,-783.98 1085.97,-774.71 1086.34,-766.11"/>
+<polygon fill="black" stroke="black" points="1089.84,-766.25 1086.77,-756.1 1082.84,-765.95 1089.84,-766.25"/>
+</g>
+<!-- 35 -->
+<g id="node27" class="node">
+<title>35</title>
+<polygon fill="none" stroke="black" points="1202,-684 1019,-684 1019,-648 1202,-648 1202,-684"/>
+<text text-anchor="middle" x="1110.5" y="-662.3" font-family="Times,serif" font-size="14.00">nn.dropout(·| rate=0.1)</text>
+</g>
+<!-- 34&#45;&gt;35 -->
+<g id="edge27" class="edge">
+<title>34&#45;&gt;35</title>
+<path fill="none" stroke="black" d="M1093.19,-719.7C1095.75,-711.9 1098.83,-702.51 1101.68,-693.83"/>
+<polygon fill="black" stroke="black" points="1105.08,-694.7 1104.88,-684.1 1098.43,-692.51 1105.08,-694.7"/>
+</g>
+<!-- 36 -->
+<g id="node28" class="node">
+<title>36</title>
+<polygon fill="none" stroke="black" points="1206.5,-612 1038.5,-612 1038.5,-576 1206.5,-576 1206.5,-612"/>
+<text text-anchor="middle" x="1122.5" y="-590.3" font-family="Times,serif" font-size="14.00">TupleGetItem(idx=0)</text>
+</g>
+<!-- 35&#45;&gt;36 -->
+<g id="edge28" class="edge">
+<title>35&#45;&gt;36</title>
+<path fill="none" stroke="black" d="M1113.47,-647.7C1114.79,-639.98 1116.38,-630.71 1117.85,-622.11"/>
+<polygon fill="black" stroke="black" points="1121.33,-622.55 1119.57,-612.1 1114.43,-621.37 1121.33,-622.55"/>
+</g>
+<!-- 37 -->
+<g id="node29" class="node">
+<title>37</title>
+<polygon fill="none" stroke="black" points="1300,-540 957,-540 957,-504 1300,-504 1300,-540"/>
+<text text-anchor="middle" x="1128.5" y="-518.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 14], reverse=0)</text>
+</g>
+<!-- 36&#45;&gt;37 -->
+<g id="edge29" class="edge">
+<title>36&#45;&gt;37</title>
+<path fill="none" stroke="black" d="M1123.98,-575.7C1124.64,-567.98 1125.44,-558.71 1126.18,-550.11"/>
+<polygon fill="black" stroke="black" points="1129.67,-550.37 1127.03,-540.1 1122.69,-549.77 1129.67,-550.37"/>
+</g>
+<!-- 47 -->
+<g id="node39" class="node">
+<title>47</title>
+<polygon fill="none" stroke="black" points="1312,-468 1143,-468 1143,-432 1312,-432 1312,-468"/>
+<text text-anchor="middle" x="1227.5" y="-446.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 37&#45;&gt;47 -->
+<g id="edge39" class="edge">
+<title>37&#45;&gt;47</title>
+<path fill="none" stroke="black" d="M1152.72,-503.88C1165.42,-494.89 1181.16,-483.76 1194.92,-474.03"/>
+<polygon fill="black" stroke="black" points="1197.04,-476.82 1203.19,-468.19 1193,-471.11 1197.04,-476.82"/>
+</g>
+<!-- 38 -->
+<g id="node30" class="node">
+<title>38</title>
+<polygon fill="none" stroke="black" points="1380.5,-1692 1138.5,-1692 1138.5,-1656 1380.5,-1656 1380.5,-1692"/>
+<text text-anchor="middle" x="1259.5" y="-1670.3" font-family="Times,serif" font-size="14.00">Constant((1, 768, 768), float32)</text>
+</g>
+<!-- 38&#45;&gt;39 -->
+<g id="edge31" class="edge">
+<title>38&#45;&gt;39</title>
+<path fill="none" stroke="black" d="M1259.5,-1655.7C1259.5,-1647.98 1259.5,-1638.71 1259.5,-1630.11"/>
+<polygon fill="black" stroke="black" points="1263,-1630.1 1259.5,-1620.1 1256,-1630.1 1263,-1630.1"/>
+</g>
+<!-- 40 -->
+<g id="node32" class="node">
+<title>40</title>
+<polygon fill="none" stroke="black" points="1499,-1548 1152,-1548 1152,-1512 1499,-1512 1499,-1548"/>
+<text text-anchor="middle" x="1325.5" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 39&#45;&gt;40 -->
+<g id="edge32" class="edge">
+<title>39&#45;&gt;40</title>
+<path fill="none" stroke="black" d="M1275.81,-1583.7C1283.81,-1575.22 1293.57,-1564.86 1302.32,-1555.58"/>
+<polygon fill="black" stroke="black" points="1305.06,-1557.78 1309.37,-1548.1 1299.97,-1552.98 1305.06,-1557.78"/>
+</g>
+<!-- 42 -->
+<g id="node34" class="node">
+<title>42</title>
+<polygon fill="none" stroke="black" points="1366,-1476 1295,-1476 1295,-1440 1366,-1440 1366,-1476"/>
+<text text-anchor="middle" x="1330.5" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 40&#45;&gt;42 -->
+<g id="edge33" class="edge">
+<title>40&#45;&gt;42</title>
+<path fill="none" stroke="black" d="M1326.74,-1511.7C1327.29,-1503.98 1327.95,-1494.71 1328.56,-1486.11"/>
+<polygon fill="black" stroke="black" points="1332.06,-1486.33 1329.28,-1476.1 1325.07,-1485.83 1332.06,-1486.33"/>
+</g>
+<!-- 41 -->
+<g id="node33" class="node">
+<title>41</title>
+<polygon fill="none" stroke="black" points="1709.5,-1548 1517.5,-1548 1517.5,-1512 1709.5,-1512 1709.5,-1548"/>
+<text text-anchor="middle" x="1613.5" y="-1526.3" font-family="Times,serif" font-size="14.00">Constant((768,), float32)</text>
+</g>
+<!-- 41&#45;&gt;42 -->
+<g id="edge34" class="edge">
+<title>41&#45;&gt;42</title>
+<path fill="none" stroke="black" d="M1544.63,-1511.97C1492.31,-1499.02 1421.8,-1481.58 1376.25,-1470.32"/>
+<polygon fill="black" stroke="black" points="1376.9,-1466.87 1366.36,-1467.87 1375.22,-1473.67 1376.9,-1466.87"/>
+</g>
+<!-- 43 -->
+<g id="node35" class="node">
+<title>43</title>
+<polygon fill="none" stroke="black" points="1513,-1404 1148,-1404 1148,-1368 1513,-1368 1513,-1404"/>
+<text text-anchor="middle" x="1330.5" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 42&#45;&gt;43 -->
+<g id="edge35" class="edge">
+<title>42&#45;&gt;43</title>
+<path fill="none" stroke="black" d="M1330.5,-1439.7C1330.5,-1431.98 1330.5,-1422.71 1330.5,-1414.11"/>
+<polygon fill="black" stroke="black" points="1334,-1414.1 1330.5,-1404.1 1327,-1414.1 1334,-1414.1"/>
+</g>
+<!-- 44 -->
+<g id="node36" class="node">
+<title>44</title>
+<polygon fill="none" stroke="black" points="1445,-1332 1216,-1332 1216,-1296 1445,-1296 1445,-1332"/>
+<text text-anchor="middle" x="1330.5" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 43&#45;&gt;44 -->
+<g id="edge36" class="edge">
+<title>43&#45;&gt;44</title>
+<path fill="none" stroke="black" d="M1330.5,-1367.7C1330.5,-1359.98 1330.5,-1350.71 1330.5,-1342.11"/>
+<polygon fill="black" stroke="black" points="1334,-1342.1 1330.5,-1332.1 1327,-1342.1 1334,-1342.1"/>
+</g>
+<!-- 45 -->
+<g id="node37" class="node">
+<title>45</title>
+<polygon fill="none" stroke="black" points="1502,-1260 1159,-1260 1159,-1224 1502,-1224 1502,-1260"/>
+<text text-anchor="middle" x="1330.5" y="-1238.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 64], reverse=0)</text>
+</g>
+<!-- 44&#45;&gt;45 -->
+<g id="edge37" class="edge">
+<title>44&#45;&gt;45</title>
+<path fill="none" stroke="black" d="M1330.5,-1295.7C1330.5,-1287.98 1330.5,-1278.71 1330.5,-1270.11"/>
+<polygon fill="black" stroke="black" points="1334,-1270.1 1330.5,-1260.1 1327,-1270.1 1334,-1270.1"/>
+</g>
+<!-- 46 -->
+<g id="node38" class="node">
+<title>46</title>
+<polygon fill="none" stroke="black" points="1436,-1044 1225,-1044 1225,-1008 1436,-1008 1436,-1044"/>
+<text text-anchor="middle" x="1330.5" y="-1022.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 45&#45;&gt;46 -->
+<g id="edge38" class="edge">
+<title>45&#45;&gt;46</title>
+<path fill="none" stroke="black" d="M1330.5,-1223.85C1330.5,-1186.83 1330.5,-1099.18 1330.5,-1054.39"/>
+<polygon fill="black" stroke="black" points="1334,-1054.23 1330.5,-1044.23 1327,-1054.23 1334,-1054.23"/>
+</g>
+<!-- 46&#45;&gt;47 -->
+<g id="edge40" class="edge">
+<title>46&#45;&gt;47</title>
+<path fill="none" stroke="black" d="M1330.5,-1007.95C1330.5,-981.29 1330.5,-928.11 1330.5,-883 1330.5,-883 1330.5,-883 1330.5,-593 1330.5,-552.36 1332.93,-537.21 1309.5,-504 1300.77,-491.62 1288.14,-481.42 1275.4,-473.4"/>
+<polygon fill="black" stroke="black" points="1276.99,-470.27 1266.6,-468.19 1273.42,-476.3 1276.99,-470.27"/>
+</g>
+<!-- 48 -->
+<g id="node40" class="node">
+<title>48</title>
+<polygon fill="none" stroke="black" points="1410,-396 1045,-396 1045,-360 1410,-360 1410,-396"/>
+<text text-anchor="middle" x="1227.5" y="-374.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 12, 14, 64], reverse=0)</text>
+</g>
+<!-- 47&#45;&gt;48 -->
+<g id="edge41" class="edge">
+<title>47&#45;&gt;48</title>
+<path fill="none" stroke="black" d="M1227.5,-431.7C1227.5,-423.98 1227.5,-414.71 1227.5,-406.11"/>
+<polygon fill="black" stroke="black" points="1231,-406.1 1227.5,-396.1 1224,-406.1 1231,-406.1"/>
+</g>
+<!-- 49 -->
+<g id="node41" class="node">
+<title>49</title>
+<polygon fill="none" stroke="black" points="1342,-324 1113,-324 1113,-288 1342,-288 1342,-324"/>
+<text text-anchor="middle" x="1227.5" y="-302.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 48&#45;&gt;49 -->
+<g id="edge42" class="edge">
+<title>48&#45;&gt;49</title>
+<path fill="none" stroke="black" d="M1227.5,-359.7C1227.5,-351.98 1227.5,-342.71 1227.5,-334.11"/>
+<polygon fill="black" stroke="black" points="1231,-334.1 1227.5,-324.1 1224,-334.1 1231,-334.1"/>
+</g>
+<!-- 50 -->
+<g id="node42" class="node">
+<title>50</title>
+<polygon fill="none" stroke="black" points="1260,-252 1195,-252 1195,-216 1260,-216 1260,-252"/>
+<text text-anchor="middle" x="1227.5" y="-230.3" font-family="Times,serif" font-size="14.00">copy(·)</text>
+</g>
+<!-- 49&#45;&gt;50 -->
+<g id="edge43" class="edge">
+<title>49&#45;&gt;50</title>
+<path fill="none" stroke="black" d="M1227.5,-287.7C1227.5,-279.98 1227.5,-270.71 1227.5,-262.11"/>
+<polygon fill="black" stroke="black" points="1231,-262.1 1227.5,-252.1 1224,-262.1 1231,-262.1"/>
+</g>
+<!-- 51 -->
+<g id="node43" class="node">
+<title>51</title>
+<polygon fill="none" stroke="black" points="1401,-180 1054,-180 1054,-144 1401,-144 1401,-180"/>
+<text text-anchor="middle" x="1227.5" y="-158.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 50&#45;&gt;51 -->
+<g id="edge44" class="edge">
+<title>50&#45;&gt;51</title>
+<path fill="none" stroke="black" d="M1227.5,-215.7C1227.5,-207.98 1227.5,-198.71 1227.5,-190.11"/>
+<polygon fill="black" stroke="black" points="1231,-190.1 1227.5,-180.1 1224,-190.1 1231,-190.1"/>
+</g>
+<!-- 52 -->
+<g id="node44" class="node">
+<title>52</title>
+<polygon fill="none" stroke="black" points="1270.5,-108 1184.5,-108 1184.5,-72 1270.5,-72 1270.5,-108"/>
+<text text-anchor="middle" x="1227.5" y="-86.3" font-family="Times,serif" font-size="14.00">Tuple[...])</text>
+</g>
+<!-- 51&#45;&gt;52 -->
+<g id="edge45" class="edge">
+<title>51&#45;&gt;52</title>
+<path fill="none" stroke="black" d="M1227.5,-143.7C1227.5,-135.98 1227.5,-126.71 1227.5,-118.11"/>
+<polygon fill="black" stroke="black" points="1231,-118.1 1227.5,-108.1 1224,-118.1 1231,-118.1"/>
+</g>
+<!-- 53 -->
+<g id="node45" class="node">
+<title>53</title>
+<polygon fill="none" stroke="black" points="1267.5,-36 1187.5,-36 1187.5,0 1267.5,0 1267.5,-36"/>
+<text text-anchor="middle" x="1227.5" y="-14.3" font-family="Times,serif" font-size="14.00">Function</text>
+</g>
+<!-- 52&#45;&gt;53 -->
+<g id="edge46" class="edge">
+<title>52&#45;&gt;53</title>
+<path fill="none" stroke="black" d="M1227.5,-71.7C1227.5,-63.98 1227.5,-54.71 1227.5,-46.11"/>
+<polygon fill="black" stroke="black" points="1231,-46.1 1227.5,-36.1 1224,-46.1 1231,-46.1"/>
+</g>
+</g>
+</svg>
diff --git a/images/bert-pytorch/bert-tvm_74_0.svg b/images/bert-pytorch/bert-tvm_74_0.svg
new file mode 100644
index 0000000..f7a2ace
--- /dev/null
+++ b/images/bert-pytorch/bert-tvm_74_0.svg
@@ -0,0 +1,547 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
+ "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
+<!-- Generated by graphviz version 2.43.0 (0)
+ -->
+<!-- Title: %3 Pages: 1 -->
+<svg width="2649pt" height="1844pt"
+ viewBox="0.00 0.00 2648.50 1844.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
+<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 1840)">
+<title>%3</title>
+<polygon fill="white" stroke="transparent" points="-4,4 -4,-1840 2644.5,-1840 2644.5,4 -4,4"/>
+<!-- 0 -->
+<g id="node1" class="node">
+<title>0</title>
+<ellipse fill="none" stroke="black" cx="1120.5" cy="-1818" rx="170.87" ry="18"/>
+<text text-anchor="middle" x="1120.5" y="-1814.3" font-family="Times,serif" font-size="14.00">input: Tensor[(1, 14, 768), float32]</text>
+</g>
+<!-- 11 -->
+<g id="node3" class="node">
+<title>11</title>
+<polygon fill="none" stroke="black" points="1296.5,-1764 944.5,-1764 944.5,-1728 1296.5,-1728 1296.5,-1764"/>
+<text text-anchor="middle" x="1120.5" y="-1742.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 0&#45;&gt;11 -->
+<g id="edge1" class="edge">
+<title>0&#45;&gt;11</title>
+<path fill="none" stroke="black" d="M1120.5,-1799.7C1120.5,-1791.98 1120.5,-1782.71 1120.5,-1774.11"/>
+<polygon fill="black" stroke="black" points="1124,-1774.1 1120.5,-1764.1 1117,-1774.1 1124,-1774.1"/>
+</g>
+<!-- 1 -->
+<g id="node2" class="node">
+<title>1</title>
+<ellipse fill="none" stroke="black" cx="1640.5" cy="-882" rx="217.96" ry="18"/>
+<text text-anchor="middle" x="1640.5" y="-878.3" font-family="Times,serif" font-size="14.00">attention_mask: Tensor[(1, 1, 1, 14), float32]</text>
+</g>
+<!-- 40 -->
+<g id="node25" class="node">
+<title>40</title>
+<polygon fill="none" stroke="black" points="1676,-828 1605,-828 1605,-792 1676,-792 1676,-828"/>
+<text text-anchor="middle" x="1640.5" y="-806.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 1&#45;&gt;40 -->
+<g id="edge25" class="edge">
+<title>1&#45;&gt;40</title>
+<path fill="none" stroke="black" d="M1640.5,-863.7C1640.5,-855.98 1640.5,-846.71 1640.5,-838.11"/>
+<polygon fill="black" stroke="black" points="1644,-838.1 1640.5,-828.1 1637,-838.1 1644,-838.1"/>
+</g>
+<!-- 13 -->
+<g id="node5" class="node">
+<title>13</title>
+<polygon fill="none" stroke="black" points="1364,-1692 1195,-1692 1195,-1656 1364,-1656 1364,-1692"/>
+<text text-anchor="middle" x="1279.5" y="-1670.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 11&#45;&gt;13 -->
+<g id="edge2" class="edge">
+<title>11&#45;&gt;13</title>
+<path fill="none" stroke="black" d="M1159.4,-1727.88C1181.13,-1718.31 1208.4,-1706.3 1231.46,-1696.15"/>
+<polygon fill="black" stroke="black" points="1232.99,-1699.3 1240.73,-1692.07 1230.17,-1692.89 1232.99,-1699.3"/>
+</g>
+<!-- 12 -->
+<g id="node4" class="node">
+<title>12</title>
+<polygon fill="none" stroke="black" points="1565,-1764 1314,-1764 1314,-1728 1565,-1728 1565,-1764"/>
+<text text-anchor="middle" x="1439.5" y="-1742.3" font-family="Times,serif" font-size="14.00">Constant((1, 2304, 768), float32)</text>
+</g>
+<!-- 12&#45;&gt;13 -->
+<g id="edge3" class="edge">
+<title>12&#45;&gt;13</title>
+<path fill="none" stroke="black" d="M1400.36,-1727.88C1378.49,-1718.31 1351.04,-1706.3 1327.84,-1696.15"/>
+<polygon fill="black" stroke="black" points="1329.08,-1692.87 1318.51,-1692.07 1326.27,-1699.28 1329.08,-1692.87"/>
+</g>
+<!-- 17 -->
+<g id="node6" class="node">
+<title>17</title>
+<polygon fill="none" stroke="black" points="823,-1620 0,-1620 0,-1584 823,-1584 823,-1620"/>
+<text text-anchor="middle" x="411.5" y="-1598.3" font-family="Times,serif" font-size="14.00">strided_slice(·, [0 0 0], [ &#45;1 &#160;&#45;1 768], [1 1 1]| begin=[0, 0, 0], end=[&#45;1, &#45;1, 768], strides=[1, 1, 1], slice_mode=size)</text>
+</g>
+<!-- 13&#45;&gt;17 -->
+<g id="edge4" class="edge">
+<title>13&#45;&gt;17</title>
+<path fill="none" stroke="black" d="M1194.69,-1666.16C1065.57,-1655.75 815.4,-1635.57 632.43,-1620.82"/>
+<polygon fill="black" stroke="black" points="632.65,-1617.32 622.4,-1620.01 632.09,-1624.3 632.65,-1617.32"/>
+</g>
+<!-- 27 -->
+<g id="node13" class="node">
+<title>27</title>
+<polygon fill="none" stroke="black" points="1718,-1620 841,-1620 841,-1584 1718,-1584 1718,-1620"/>
+<text text-anchor="middle" x="1279.5" y="-1598.3" font-family="Times,serif" font-size="14.00">strided_slice(·, [ &#160;0 &#160;&#160;0 768], [ &#45;1 &#160;&#45;1 768], [1 1 1]| begin=[0, 0, 768], end=[&#45;1, &#45;1, 768], strides=[1, 1, 1], slice_mode=size)</text>
+</g>
+<!-- 13&#45;&gt;27 -->
+<g id="edge11" class="edge">
+<title>13&#45;&gt;27</title>
+<path fill="none" stroke="black" d="M1279.5,-1655.7C1279.5,-1647.98 1279.5,-1638.71 1279.5,-1630.11"/>
+<polygon fill="black" stroke="black" points="1283,-1630.1 1279.5,-1620.1 1276,-1630.1 1283,-1630.1"/>
+</g>
+<!-- 48 -->
+<g id="node30" class="node">
+<title>48</title>
+<polygon fill="none" stroke="black" points="2640.5,-1620 1736.5,-1620 1736.5,-1584 2640.5,-1584 2640.5,-1620"/>
+<text text-anchor="middle" x="2188.5" y="-1598.3" font-family="Times,serif" font-size="14.00">strided_slice(·, [ &#160;&#160;0 &#160;&#160;&#160;0 1536], [ &#45;1 &#160;&#45;1 768], [1 1 1]| begin=[0, 0, 1536], end=[&#45;1, &#45;1, 768], strides=[1, 1, 1], slice_mode=size)</text>
+</g>
+<!-- 13&#45;&gt;48 -->
+<g id="edge30" class="edge">
+<title>13&#45;&gt;48</title>
+<path fill="none" stroke="black" d="M1364.2,-1666.48C1498.28,-1656.15 1764.03,-1635.69 1957.65,-1620.78"/>
+<polygon fill="black" stroke="black" points="1957.98,-1624.26 1967.68,-1620 1957.44,-1617.28 1957.98,-1624.26"/>
+</g>
+<!-- 18 -->
+<g id="node7" class="node">
+<title>18</title>
+<polygon fill="none" stroke="black" points="631,-1548 284,-1548 284,-1512 631,-1512 631,-1548"/>
+<text text-anchor="middle" x="457.5" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 17&#45;&gt;18 -->
+<g id="edge5" class="edge">
+<title>17&#45;&gt;18</title>
+<path fill="none" stroke="black" d="M422.87,-1583.7C428.22,-1575.56 434.7,-1565.69 440.61,-1556.7"/>
+<polygon fill="black" stroke="black" points="443.69,-1558.38 446.26,-1548.1 437.84,-1554.54 443.69,-1558.38"/>
+</g>
+<!-- 20 -->
+<g id="node9" class="node">
+<title>20</title>
+<polygon fill="none" stroke="black" points="806,-1476 735,-1476 735,-1440 806,-1440 806,-1476"/>
+<text text-anchor="middle" x="770.5" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 18&#45;&gt;20 -->
+<g id="edge6" class="edge">
+<title>18&#45;&gt;20</title>
+<path fill="none" stroke="black" d="M533.67,-1511.97C593.43,-1498.6 674.63,-1480.44 724.75,-1469.23"/>
+<polygon fill="black" stroke="black" points="725.65,-1472.62 734.65,-1467.02 724.12,-1465.79 725.65,-1472.62"/>
+</g>
+<!-- 19 -->
+<g id="node8" class="node">
+<title>19</title>
+<polygon fill="none" stroke="black" points="866.5,-1548 674.5,-1548 674.5,-1512 866.5,-1512 866.5,-1548"/>
+<text text-anchor="middle" x="770.5" y="-1526.3" font-family="Times,serif" font-size="14.00">Constant((768,), float32)</text>
+</g>
+<!-- 19&#45;&gt;20 -->
+<g id="edge7" class="edge">
+<title>19&#45;&gt;20</title>
+<path fill="none" stroke="black" d="M770.5,-1511.7C770.5,-1503.98 770.5,-1494.71 770.5,-1486.11"/>
+<polygon fill="black" stroke="black" points="774,-1486.1 770.5,-1476.1 767,-1486.1 774,-1486.1"/>
+</g>
+<!-- 21 -->
+<g id="node10" class="node">
+<title>21</title>
+<polygon fill="none" stroke="black" points="1051,-1404 686,-1404 686,-1368 1051,-1368 1051,-1404"/>
+<text text-anchor="middle" x="868.5" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 20&#45;&gt;21 -->
+<g id="edge8" class="edge">
+<title>20&#45;&gt;21</title>
+<path fill="none" stroke="black" d="M794.47,-1439.88C807.05,-1430.89 822.63,-1419.76 836.25,-1410.03"/>
+<polygon fill="black" stroke="black" points="838.33,-1412.85 844.43,-1404.19 834.26,-1407.15 838.33,-1412.85"/>
+</g>
+<!-- 22 -->
+<g id="node11" class="node">
+<title>22</title>
+<polygon fill="none" stroke="black" points="1043,-1332 814,-1332 814,-1296 1043,-1296 1043,-1332"/>
+<text text-anchor="middle" x="928.5" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 21&#45;&gt;22 -->
+<g id="edge9" class="edge">
+<title>21&#45;&gt;22</title>
+<path fill="none" stroke="black" d="M883.33,-1367.7C890.52,-1359.3 899.3,-1349.07 907.19,-1339.86"/>
+<polygon fill="black" stroke="black" points="909.99,-1341.97 913.84,-1332.1 904.67,-1337.42 909.99,-1341.97"/>
+</g>
+<!-- 23 -->
+<g id="node12" class="node">
+<title>23</title>
+<polygon fill="none" stroke="black" points="1166,-1116 823,-1116 823,-1080 1166,-1080 1166,-1116"/>
+<text text-anchor="middle" x="994.5" y="-1094.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 64], reverse=0)</text>
+</g>
+<!-- 22&#45;&gt;23 -->
+<g id="edge10" class="edge">
+<title>22&#45;&gt;23</title>
+<path fill="none" stroke="black" d="M933.79,-1295.85C945.25,-1258.68 972.47,-1170.44 986.23,-1125.82"/>
+<polygon fill="black" stroke="black" points="989.58,-1126.82 989.19,-1116.23 982.89,-1124.76 989.58,-1126.82"/>
+</g>
+<!-- 36 -->
+<g id="node22" class="node">
+<title>36</title>
+<polygon fill="none" stroke="black" points="1385,-1044 1216,-1044 1216,-1008 1385,-1008 1385,-1044"/>
+<text text-anchor="middle" x="1300.5" y="-1022.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 23&#45;&gt;36 -->
+<g id="edge20" class="edge">
+<title>23&#45;&gt;36</title>
+<path fill="none" stroke="black" d="M1068.96,-1079.97C1113.38,-1069.8 1169.92,-1056.87 1216,-1046.33"/>
+<polygon fill="black" stroke="black" points="1217,-1049.69 1225.97,-1044.05 1215.44,-1042.87 1217,-1049.69"/>
+</g>
+<!-- 28 -->
+<g id="node14" class="node">
+<title>28</title>
+<polygon fill="none" stroke="black" points="1453,-1548 1106,-1548 1106,-1512 1453,-1512 1453,-1548"/>
+<text text-anchor="middle" x="1279.5" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 27&#45;&gt;28 -->
+<g id="edge12" class="edge">
+<title>27&#45;&gt;28</title>
+<path fill="none" stroke="black" d="M1279.5,-1583.7C1279.5,-1575.98 1279.5,-1566.71 1279.5,-1558.11"/>
+<polygon fill="black" stroke="black" points="1283,-1558.1 1279.5,-1548.1 1276,-1558.1 1283,-1558.1"/>
+</g>
+<!-- 30 -->
+<g id="node16" class="node">
+<title>30</title>
+<polygon fill="none" stroke="black" points="1336,-1476 1265,-1476 1265,-1440 1336,-1440 1336,-1476"/>
+<text text-anchor="middle" x="1300.5" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 28&#45;&gt;30 -->
+<g id="edge13" class="edge">
+<title>28&#45;&gt;30</title>
+<path fill="none" stroke="black" d="M1284.69,-1511.7C1287.03,-1503.9 1289.85,-1494.51 1292.45,-1485.83"/>
+<polygon fill="black" stroke="black" points="1295.85,-1486.69 1295.37,-1476.1 1289.14,-1484.68 1295.85,-1486.69"/>
+</g>
+<!-- 29 -->
+<g id="node15" class="node">
+<title>29</title>
+<polygon fill="none" stroke="black" points="1663.5,-1548 1471.5,-1548 1471.5,-1512 1663.5,-1512 1663.5,-1548"/>
+<text text-anchor="middle" x="1567.5" y="-1526.3" font-family="Times,serif" font-size="14.00">Constant((768,), float32)</text>
+</g>
+<!-- 29&#45;&gt;30 -->
+<g id="edge14" class="edge">
+<title>29&#45;&gt;30</title>
+<path fill="none" stroke="black" d="M1502.53,-1511.97C1454.08,-1499.26 1389.12,-1482.23 1346.09,-1470.95"/>
+<polygon fill="black" stroke="black" points="1346.75,-1467.51 1336.19,-1468.36 1344.97,-1474.28 1346.75,-1467.51"/>
+</g>
+<!-- 31 -->
+<g id="node17" class="node">
+<title>31</title>
+<polygon fill="none" stroke="black" points="1483,-1404 1118,-1404 1118,-1368 1483,-1368 1483,-1404"/>
+<text text-anchor="middle" x="1300.5" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 30&#45;&gt;31 -->
+<g id="edge15" class="edge">
+<title>30&#45;&gt;31</title>
+<path fill="none" stroke="black" d="M1300.5,-1439.7C1300.5,-1431.98 1300.5,-1422.71 1300.5,-1414.11"/>
+<polygon fill="black" stroke="black" points="1304,-1414.1 1300.5,-1404.1 1297,-1414.1 1304,-1414.1"/>
+</g>
+<!-- 32 -->
+<g id="node18" class="node">
+<title>32</title>
+<polygon fill="none" stroke="black" points="1415,-1332 1186,-1332 1186,-1296 1415,-1296 1415,-1332"/>
+<text text-anchor="middle" x="1300.5" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 31&#45;&gt;32 -->
+<g id="edge16" class="edge">
+<title>31&#45;&gt;32</title>
+<path fill="none" stroke="black" d="M1300.5,-1367.7C1300.5,-1359.98 1300.5,-1350.71 1300.5,-1342.11"/>
+<polygon fill="black" stroke="black" points="1304,-1342.1 1300.5,-1332.1 1297,-1342.1 1304,-1342.1"/>
+</g>
+<!-- 33 -->
+<g id="node19" class="node">
+<title>33</title>
+<polygon fill="none" stroke="black" points="1415,-1260 1186,-1260 1186,-1224 1415,-1224 1415,-1260"/>
+<text text-anchor="middle" x="1300.5" y="-1238.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 1, 3, 2])</text>
+</g>
+<!-- 32&#45;&gt;33 -->
+<g id="edge17" class="edge">
+<title>32&#45;&gt;33</title>
+<path fill="none" stroke="black" d="M1300.5,-1295.7C1300.5,-1287.98 1300.5,-1278.71 1300.5,-1270.11"/>
+<polygon fill="black" stroke="black" points="1304,-1270.1 1300.5,-1260.1 1297,-1270.1 1304,-1270.1"/>
+</g>
+<!-- 34 -->
+<g id="node20" class="node">
+<title>34</title>
+<polygon fill="none" stroke="black" points="1472,-1188 1129,-1188 1129,-1152 1472,-1152 1472,-1188"/>
+<text text-anchor="middle" x="1300.5" y="-1166.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 64, 14], reverse=0)</text>
+</g>
+<!-- 33&#45;&gt;34 -->
+<g id="edge18" class="edge">
+<title>33&#45;&gt;34</title>
+<path fill="none" stroke="black" d="M1300.5,-1223.7C1300.5,-1215.98 1300.5,-1206.71 1300.5,-1198.11"/>
+<polygon fill="black" stroke="black" points="1304,-1198.1 1300.5,-1188.1 1297,-1198.1 1304,-1198.1"/>
+</g>
+<!-- 35 -->
+<g id="node21" class="node">
+<title>35</title>
+<polygon fill="none" stroke="black" points="1406,-1116 1195,-1116 1195,-1080 1406,-1080 1406,-1116"/>
+<text text-anchor="middle" x="1300.5" y="-1094.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 34&#45;&gt;35 -->
+<g id="edge19" class="edge">
+<title>34&#45;&gt;35</title>
+<path fill="none" stroke="black" d="M1300.5,-1151.7C1300.5,-1143.98 1300.5,-1134.71 1300.5,-1126.11"/>
+<polygon fill="black" stroke="black" points="1304,-1126.1 1300.5,-1116.1 1297,-1126.1 1304,-1126.1"/>
+</g>
+<!-- 35&#45;&gt;36 -->
+<g id="edge21" class="edge">
+<title>35&#45;&gt;36</title>
+<path fill="none" stroke="black" d="M1300.5,-1079.7C1300.5,-1071.98 1300.5,-1062.71 1300.5,-1054.11"/>
+<polygon fill="black" stroke="black" points="1304,-1054.1 1300.5,-1044.1 1297,-1054.1 1304,-1054.1"/>
+</g>
+<!-- 37 -->
+<g id="node23" class="node">
+<title>37</title>
+<polygon fill="none" stroke="black" points="1504,-972 1139,-972 1139,-936 1504,-936 1504,-972"/>
+<text text-anchor="middle" x="1321.5" y="-950.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 12, 14, 14], reverse=0)</text>
+</g>
+<!-- 36&#45;&gt;37 -->
+<g id="edge22" class="edge">
+<title>36&#45;&gt;37</title>
+<path fill="none" stroke="black" d="M1305.69,-1007.7C1308.03,-999.9 1310.85,-990.51 1313.45,-981.83"/>
+<polygon fill="black" stroke="black" points="1316.85,-982.69 1316.37,-972.1 1310.14,-980.68 1316.85,-982.69"/>
+</g>
+<!-- 39 -->
+<g id="node24" class="node">
+<title>39</title>
+<polygon fill="none" stroke="black" points="1395,-900 1288,-900 1288,-864 1395,-864 1395,-900"/>
+<text text-anchor="middle" x="1341.5" y="-878.3" font-family="Times,serif" font-size="14.00">divide(·, 8.0)</text>
+</g>
+<!-- 37&#45;&gt;39 -->
+<g id="edge23" class="edge">
+<title>37&#45;&gt;39</title>
+<path fill="none" stroke="black" d="M1326.44,-935.7C1328.67,-927.9 1331.35,-918.51 1333.83,-909.83"/>
+<polygon fill="black" stroke="black" points="1337.23,-910.68 1336.61,-900.1 1330.5,-908.76 1337.23,-910.68"/>
+</g>
+<!-- 39&#45;&gt;40 -->
+<g id="edge24" class="edge">
+<title>39&#45;&gt;40</title>
+<path fill="none" stroke="black" d="M1395.36,-868.27C1401.48,-866.83 1407.62,-865.38 1413.5,-864 1476.32,-849.25 1548.94,-832.31 1594.79,-821.64"/>
+<polygon fill="black" stroke="black" points="1595.77,-825 1604.72,-819.33 1594.18,-818.18 1595.77,-825"/>
+</g>
+<!-- 41 -->
+<g id="node26" class="node">
+<title>41</title>
+<polygon fill="none" stroke="black" points="1735.5,-756 1561.5,-756 1561.5,-720 1735.5,-720 1735.5,-756"/>
+<text text-anchor="middle" x="1648.5" y="-734.3" font-family="Times,serif" font-size="14.00">nn.softmax(·| axis=&#45;1)</text>
+</g>
+<!-- 40&#45;&gt;41 -->
+<g id="edge26" class="edge">
+<title>40&#45;&gt;41</title>
+<path fill="none" stroke="black" d="M1642.48,-791.7C1643.36,-783.98 1644.42,-774.71 1645.4,-766.11"/>
+<polygon fill="black" stroke="black" points="1648.89,-766.44 1646.55,-756.1 1641.93,-765.64 1648.89,-766.44"/>
+</g>
+<!-- 42 -->
+<g id="node27" class="node">
+<title>42</title>
+<polygon fill="none" stroke="black" points="1763,-684 1580,-684 1580,-648 1763,-648 1763,-684"/>
+<text text-anchor="middle" x="1671.5" y="-662.3" font-family="Times,serif" font-size="14.00">nn.dropout(·| rate=0.1)</text>
+</g>
+<!-- 41&#45;&gt;42 -->
+<g id="edge27" class="edge">
+<title>41&#45;&gt;42</title>
+<path fill="none" stroke="black" d="M1654.19,-719.7C1656.75,-711.9 1659.83,-702.51 1662.68,-693.83"/>
+<polygon fill="black" stroke="black" points="1666.08,-694.7 1665.88,-684.1 1659.43,-692.51 1666.08,-694.7"/>
+</g>
+<!-- 43 -->
+<g id="node28" class="node">
+<title>43</title>
+<polygon fill="none" stroke="black" points="1767.5,-612 1599.5,-612 1599.5,-576 1767.5,-576 1767.5,-612"/>
+<text text-anchor="middle" x="1683.5" y="-590.3" font-family="Times,serif" font-size="14.00">TupleGetItem(idx=0)</text>
+</g>
+<!-- 42&#45;&gt;43 -->
+<g id="edge28" class="edge">
+<title>42&#45;&gt;43</title>
+<path fill="none" stroke="black" d="M1674.47,-647.7C1675.79,-639.98 1677.38,-630.71 1678.85,-622.11"/>
+<polygon fill="black" stroke="black" points="1682.33,-622.55 1680.57,-612.1 1675.43,-621.37 1682.33,-622.55"/>
+</g>
+<!-- 44 -->
+<g id="node29" class="node">
+<title>44</title>
+<polygon fill="none" stroke="black" points="1861,-540 1518,-540 1518,-504 1861,-504 1861,-540"/>
+<text text-anchor="middle" x="1689.5" y="-518.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 14], reverse=0)</text>
+</g>
+<!-- 43&#45;&gt;44 -->
+<g id="edge29" class="edge">
+<title>43&#45;&gt;44</title>
+<path fill="none" stroke="black" d="M1684.98,-575.7C1685.64,-567.98 1686.44,-558.71 1687.18,-550.11"/>
+<polygon fill="black" stroke="black" points="1690.67,-550.37 1688.03,-540.1 1683.69,-549.77 1690.67,-550.37"/>
+</g>
+<!-- 56 -->
+<g id="node38" class="node">
+<title>56</title>
+<polygon fill="none" stroke="black" points="1873,-468 1704,-468 1704,-432 1873,-432 1873,-468"/>
+<text text-anchor="middle" x="1788.5" y="-446.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 44&#45;&gt;56 -->
+<g id="edge38" class="edge">
+<title>44&#45;&gt;56</title>
+<path fill="none" stroke="black" d="M1713.72,-503.88C1726.42,-494.89 1742.16,-483.76 1755.92,-474.03"/>
+<polygon fill="black" stroke="black" points="1758.04,-476.82 1764.19,-468.19 1754,-471.11 1758.04,-476.82"/>
+</g>
+<!-- 49 -->
+<g id="node31" class="node">
+<title>49</title>
+<polygon fill="none" stroke="black" points="2195,-1548 1848,-1548 1848,-1512 2195,-1512 2195,-1548"/>
+<text text-anchor="middle" x="2021.5" y="-1526.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 48&#45;&gt;49 -->
+<g id="edge31" class="edge">
+<title>48&#45;&gt;49</title>
+<path fill="none" stroke="black" d="M2147.65,-1583.88C2124.72,-1574.26 2095.92,-1562.19 2071.63,-1552.01"/>
+<polygon fill="black" stroke="black" points="2072.8,-1548.71 2062.22,-1548.07 2070.09,-1555.16 2072.8,-1548.71"/>
+</g>
+<!-- 51 -->
+<g id="node33" class="node">
+<title>51</title>
+<polygon fill="none" stroke="black" points="2057,-1476 1986,-1476 1986,-1440 2057,-1440 2057,-1476"/>
+<text text-anchor="middle" x="2021.5" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 49&#45;&gt;51 -->
+<g id="edge32" class="edge">
+<title>49&#45;&gt;51</title>
+<path fill="none" stroke="black" d="M2021.5,-1511.7C2021.5,-1503.98 2021.5,-1494.71 2021.5,-1486.11"/>
+<polygon fill="black" stroke="black" points="2025,-1486.1 2021.5,-1476.1 2018,-1486.1 2025,-1486.1"/>
+</g>
+<!-- 50 -->
+<g id="node32" class="node">
+<title>50</title>
+<polygon fill="none" stroke="black" points="2405.5,-1548 2213.5,-1548 2213.5,-1512 2405.5,-1512 2405.5,-1548"/>
+<text text-anchor="middle" x="2309.5" y="-1526.3" font-family="Times,serif" font-size="14.00">Constant((768,), float32)</text>
+</g>
+<!-- 50&#45;&gt;51 -->
+<g id="edge33" class="edge">
+<title>50&#45;&gt;51</title>
+<path fill="none" stroke="black" d="M2239.42,-1511.97C2185.74,-1498.92 2113.28,-1481.31 2066.96,-1470.05"/>
+<polygon fill="black" stroke="black" points="2067.74,-1466.64 2057.2,-1467.68 2066.09,-1473.44 2067.74,-1466.64"/>
+</g>
+<!-- 52 -->
+<g id="node34" class="node">
+<title>52</title>
+<polygon fill="none" stroke="black" points="2162,-1404 1797,-1404 1797,-1368 2162,-1368 2162,-1404"/>
+<text text-anchor="middle" x="1979.5" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 12, 64], reverse=0)</text>
+</g>
+<!-- 51&#45;&gt;52 -->
+<g id="edge34" class="edge">
+<title>51&#45;&gt;52</title>
+<path fill="none" stroke="black" d="M2011.12,-1439.7C2006.24,-1431.56 2000.31,-1421.69 1994.92,-1412.7"/>
+<polygon fill="black" stroke="black" points="1997.91,-1410.88 1989.76,-1404.1 1991.91,-1414.48 1997.91,-1410.88"/>
+</g>
+<!-- 53 -->
+<g id="node35" class="node">
+<title>53</title>
+<polygon fill="none" stroke="black" points="2094,-1332 1865,-1332 1865,-1296 2094,-1296 2094,-1332"/>
+<text text-anchor="middle" x="1979.5" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 52&#45;&gt;53 -->
+<g id="edge35" class="edge">
+<title>52&#45;&gt;53</title>
+<path fill="none" stroke="black" d="M1979.5,-1367.7C1979.5,-1359.98 1979.5,-1350.71 1979.5,-1342.11"/>
+<polygon fill="black" stroke="black" points="1983,-1342.1 1979.5,-1332.1 1976,-1342.1 1983,-1342.1"/>
+</g>
+<!-- 54 -->
+<g id="node36" class="node">
+<title>54</title>
+<polygon fill="none" stroke="black" points="2150,-1260 1807,-1260 1807,-1224 2150,-1224 2150,-1260"/>
+<text text-anchor="middle" x="1978.5" y="-1238.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[&#45;1, 14, 64], reverse=0)</text>
+</g>
+<!-- 53&#45;&gt;54 -->
+<g id="edge36" class="edge">
+<title>53&#45;&gt;54</title>
+<path fill="none" stroke="black" d="M1979.25,-1295.7C1979.14,-1287.98 1979.01,-1278.71 1978.89,-1270.11"/>
+<polygon fill="black" stroke="black" points="1982.39,-1270.05 1978.74,-1260.1 1975.39,-1270.15 1982.39,-1270.05"/>
+</g>
+<!-- 55 -->
+<g id="node37" class="node">
+<title>55</title>
+<polygon fill="none" stroke="black" points="2080,-1044 1869,-1044 1869,-1008 2080,-1008 2080,-1044"/>
+<text text-anchor="middle" x="1974.5" y="-1022.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 54&#45;&gt;55 -->
+<g id="edge37" class="edge">
+<title>54&#45;&gt;55</title>
+<path fill="none" stroke="black" d="M1978.18,-1223.85C1977.49,-1186.83 1975.85,-1099.18 1975.01,-1054.39"/>
+<polygon fill="black" stroke="black" points="1978.51,-1054.17 1974.82,-1044.23 1971.51,-1054.3 1978.51,-1054.17"/>
+</g>
+<!-- 55&#45;&gt;56 -->
+<g id="edge39" class="edge">
+<title>55&#45;&gt;56</title>
+<path fill="none" stroke="black" d="M1956.08,-1007.64C1931.63,-982.62 1891.5,-933.83 1891.5,-883 1891.5,-883 1891.5,-883 1891.5,-593 1891.5,-552.36 1893.93,-537.21 1870.5,-504 1861.77,-491.62 1849.14,-481.42 1836.4,-473.4"/>
+<polygon fill="black" stroke="black" points="1837.99,-470.27 1827.6,-468.19 1834.42,-476.3 1837.99,-470.27"/>
+</g>
+<!-- 57 -->
+<g id="node39" class="node">
+<title>57</title>
+<polygon fill="none" stroke="black" points="1971,-396 1606,-396 1606,-360 1971,-360 1971,-396"/>
+<text text-anchor="middle" x="1788.5" y="-374.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 12, 14, 64], reverse=0)</text>
+</g>
+<!-- 56&#45;&gt;57 -->
+<g id="edge40" class="edge">
+<title>56&#45;&gt;57</title>
+<path fill="none" stroke="black" d="M1788.5,-431.7C1788.5,-423.98 1788.5,-414.71 1788.5,-406.11"/>
+<polygon fill="black" stroke="black" points="1792,-406.1 1788.5,-396.1 1785,-406.1 1792,-406.1"/>
+</g>
+<!-- 58 -->
+<g id="node40" class="node">
+<title>58</title>
+<polygon fill="none" stroke="black" points="1903,-324 1674,-324 1674,-288 1903,-288 1903,-324"/>
+<text text-anchor="middle" x="1788.5" y="-302.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 57&#45;&gt;58 -->
+<g id="edge41" class="edge">
+<title>57&#45;&gt;58</title>
+<path fill="none" stroke="black" d="M1788.5,-359.7C1788.5,-351.98 1788.5,-342.71 1788.5,-334.11"/>
+<polygon fill="black" stroke="black" points="1792,-334.1 1788.5,-324.1 1785,-334.1 1792,-334.1"/>
+</g>
+<!-- 59 -->
+<g id="node41" class="node">
+<title>59</title>
+<polygon fill="none" stroke="black" points="1821,-252 1756,-252 1756,-216 1821,-216 1821,-252"/>
+<text text-anchor="middle" x="1788.5" y="-230.3" font-family="Times,serif" font-size="14.00">copy(·)</text>
+</g>
+<!-- 58&#45;&gt;59 -->
+<g id="edge42" class="edge">
+<title>58&#45;&gt;59</title>
+<path fill="none" stroke="black" d="M1788.5,-287.7C1788.5,-279.98 1788.5,-270.71 1788.5,-262.11"/>
+<polygon fill="black" stroke="black" points="1792,-262.1 1788.5,-252.1 1785,-262.1 1792,-262.1"/>
+</g>
+<!-- 60 -->
+<g id="node42" class="node">
+<title>60</title>
+<polygon fill="none" stroke="black" points="1962,-180 1615,-180 1615,-144 1962,-144 1962,-180"/>
+<text text-anchor="middle" x="1788.5" y="-158.3" font-family="Times,serif" font-size="14.00">reshape(·| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 59&#45;&gt;60 -->
+<g id="edge43" class="edge">
+<title>59&#45;&gt;60</title>
+<path fill="none" stroke="black" d="M1788.5,-215.7C1788.5,-207.98 1788.5,-198.71 1788.5,-190.11"/>
+<polygon fill="black" stroke="black" points="1792,-190.1 1788.5,-180.1 1785,-190.1 1792,-190.1"/>
+</g>
+<!-- 61 -->
+<g id="node43" class="node">
+<title>61</title>
+<polygon fill="none" stroke="black" points="1831.5,-108 1745.5,-108 1745.5,-72 1831.5,-72 1831.5,-108"/>
+<text text-anchor="middle" x="1788.5" y="-86.3" font-family="Times,serif" font-size="14.00">Tuple[...])</text>
+</g>
+<!-- 60&#45;&gt;61 -->
+<g id="edge44" class="edge">
+<title>60&#45;&gt;61</title>
+<path fill="none" stroke="black" d="M1788.5,-143.7C1788.5,-135.98 1788.5,-126.71 1788.5,-118.11"/>
+<polygon fill="black" stroke="black" points="1792,-118.1 1788.5,-108.1 1785,-118.1 1792,-118.1"/>
+</g>
+<!-- 62 -->
+<g id="node44" class="node">
+<title>62</title>
+<polygon fill="none" stroke="black" points="1828.5,-36 1748.5,-36 1748.5,0 1828.5,0 1828.5,-36"/>
+<text text-anchor="middle" x="1788.5" y="-14.3" font-family="Times,serif" font-size="14.00">Function</text>
+</g>
+<!-- 61&#45;&gt;62 -->
+<g id="edge45" class="edge">
+<title>61&#45;&gt;62</title>
+<path fill="none" stroke="black" d="M1788.5,-71.7C1788.5,-63.98 1788.5,-54.71 1788.5,-46.11"/>
+<polygon fill="black" stroke="black" points="1792,-46.1 1788.5,-36.1 1785,-46.1 1792,-46.1"/>
+</g>
+</g>
+</svg>
diff --git a/images/bert-pytorch/bert_layer.svg b/images/bert-pytorch/bert_layer.svg
new file mode 100644
index 0000000..3fca855
--- /dev/null
+++ b/images/bert-pytorch/bert_layer.svg
@@ -0,0 +1,234 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
+ "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
+<!-- Generated by graphviz version 2.43.0 (0)
+ -->
+<!-- Title: %3 Pages: 1 -->
+<svg width="1433pt" height="793pt"
+ viewBox="0.00 0.00 1432.74 793.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
+<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 789)">
+<title>%3</title>
+<polygon fill="white" stroke="transparent" points="-4,4 -4,-789 1428.74,-789 1428.74,4 -4,4"/>
+<text text-anchor="middle" x="712.37" y="-769.8" font-family="Times,serif" font-size="14.00">BertLayer</text>
+<g id="clust1" class="cluster">
+<title>cluster_.9</title>
+<polygon fill="none" stroke="black" points="63,-208 63,-718 1215,-718 1215,-208 63,-208"/>
+<text text-anchor="middle" x="639" y="-702.8" font-family="Times,serif" font-size="14.00">attention (BertAttention)</text>
+</g>
+<g id="clust2" class="cluster">
+<title>cluster_attention..7</title>
+<polygon fill="none" stroke="black" points="83,-352 83,-643 907,-643 907,-352 83,-352"/>
+<text text-anchor="middle" x="495" y="-627.8" font-family="Times,serif" font-size="14.00">attention.self (BertSelfAttention)</text>
+</g>
+<!-- inp_1 -->
+<g id="node1" class="node">
+<title>inp_1</title>
+<ellipse fill="none" stroke="black" cx="172" cy="-744" rx="36" ry="18"/>
+<text text-anchor="middle" x="172" y="-740.3" font-family="Times,serif" font-size="14.00">inp_1</text>
+</g>
+<!-- attention.inp_1 -->
+<g id="node3" class="node">
+<title>attention.inp_1</title>
+<ellipse fill="none" stroke="black" cx="172" cy="-669" rx="81.49" ry="18"/>
+<text text-anchor="middle" x="172" y="-665.3" font-family="Times,serif" font-size="14.00">attention.inp_1</text>
+</g>
+<!-- inp_1&#45;&gt;attention.inp_1 -->
+<g id="edge1" class="edge">
+<title>inp_1&#45;&gt;attention.inp_1</title>
+<path fill="none" stroke="black" d="M172,-725.7C172,-717.25 172,-706.87 172,-697.37"/>
+<polygon fill="black" stroke="black" points="175.5,-697.18 172,-687.18 168.5,-697.18 175.5,-697.18"/>
+</g>
+<!-- inp_attention_mask -->
+<g id="node2" class="node">
+<title>inp_attention_mask</title>
+<ellipse fill="none" stroke="black" cx="1324" cy="-669" rx="100.98" ry="18"/>
+<text text-anchor="middle" x="1324" y="-665.3" font-family="Times,serif" font-size="14.00">inp_attention_mask</text>
+</g>
+<!-- attention.inp_attention_mask -->
+<g id="node4" class="node">
+<title>attention.inp_attention_mask</title>
+<ellipse fill="none" stroke="black" cx="1061" cy="-594" rx="146.47" ry="18"/>
+<text text-anchor="middle" x="1061" y="-590.3" font-family="Times,serif" font-size="14.00">attention.inp_attention_mask</text>
+</g>
+<!-- inp_attention_mask&#45;&gt;attention.inp_attention_mask -->
+<g id="edge2" class="edge">
+<title>inp_attention_mask&#45;&gt;attention.inp_attention_mask</title>
+<path fill="none" stroke="black" d="M1272.04,-653.58C1230.48,-642.04 1171.94,-625.79 1127.05,-613.33"/>
+<polygon fill="black" stroke="black" points="1127.98,-609.96 1117.4,-610.66 1126.1,-616.7 1127.98,-609.96"/>
+</g>
+<!-- attention.self.inp_1 -->
+<g id="node5" class="node">
+<title>attention.self.inp_1</title>
+<ellipse fill="none" stroke="black" cx="206" cy="-594" rx="100.98" ry="18"/>
+<text text-anchor="middle" x="206" y="-590.3" font-family="Times,serif" font-size="14.00">attention.self.inp_1</text>
+</g>
+<!-- attention.inp_1&#45;&gt;attention.self.inp_1 -->
+<g id="edge3" class="edge">
+<title>attention.inp_1&#45;&gt;attention.self.inp_1</title>
+<path fill="none" stroke="black" d="M179.89,-651.07C184.02,-642.2 189.17,-631.13 193.81,-621.18"/>
+<polygon fill="black" stroke="black" points="196.99,-622.63 198.04,-612.09 190.64,-619.68 196.99,-622.63"/>
+</g>
+<!-- attention..8 -->
+<g id="node12" class="node">
+<title>attention..8</title>
+<polygon fill="none" stroke="black" points="329,-324 71,-324 71,-288 329,-288 329,-324"/>
+<text text-anchor="middle" x="200" y="-302.3" font-family="Times,serif" font-size="14.00">attention.output (BertSelfOutput)</text>
+</g>
+<!-- attention.inp_1&#45;&gt;attention..8 -->
+<g id="edge13" class="edge">
+<title>attention.inp_1&#45;&gt;attention..8</title>
+<path fill="none" stroke="black" d="M102.14,-659.76C93.24,-655.92 85.12,-650.52 79,-643 45.27,-601.57 72,-576.42 72,-523 72,-523 72,-523 72,-449 72,-405.78 53.07,-386.58 79,-352 86.3,-342.26 95.95,-334.63 106.62,-328.65"/>
+<polygon fill="black" stroke="black" points="108.45,-331.65 115.79,-324.01 105.29,-325.4 108.45,-331.65"/>
+</g>
+<!-- attention.self.inp_attention_mask -->
+<g id="node6" class="node">
+<title>attention.self.inp_attention_mask</title>
+<ellipse fill="none" stroke="black" cx="733" cy="-522" rx="165.97" ry="18"/>
+<text text-anchor="middle" x="733" y="-518.3" font-family="Times,serif" font-size="14.00">attention.self.inp_attention_mask</text>
+</g>
+<!-- attention.inp_attention_mask&#45;&gt;attention.self.inp_attention_mask -->
+<g id="edge4" class="edge">
+<title>attention.inp_attention_mask&#45;&gt;attention.self.inp_attention_mask</title>
+<path fill="none" stroke="black" d="M991.44,-578.15C939.66,-567.1 868.82,-551.99 814.4,-540.37"/>
+<polygon fill="black" stroke="black" points="814.95,-536.91 804.44,-538.25 813.49,-543.76 814.95,-536.91"/>
+</g>
+<!-- attention.self..111 -->
+<g id="node7" class="node">
+<title>attention.self..111</title>
+<polygon fill="none" stroke="black" points="329,-540 111,-540 111,-504 329,-504 329,-540"/>
+<text text-anchor="middle" x="220" y="-518.3" font-family="Times,serif" font-size="14.00">attention.self.query (Linear)</text>
+</g>
+<!-- attention.self.inp_1&#45;&gt;attention.self..111 -->
+<g id="edge5" class="edge">
+<title>attention.self.inp_1&#45;&gt;attention.self..111</title>
+<path fill="none" stroke="black" d="M209.46,-575.7C211,-567.98 212.86,-558.71 214.58,-550.11"/>
+<polygon fill="black" stroke="black" points="218.05,-550.6 216.58,-540.1 211.19,-549.22 218.05,-550.6"/>
+</g>
+<!-- attention.self..112 -->
+<g id="node8" class="node">
+<title>attention.self..112</title>
+<polygon fill="none" stroke="black" points="549,-540 347,-540 347,-504 549,-504 549,-540"/>
+<text text-anchor="middle" x="448" y="-518.3" font-family="Times,serif" font-size="14.00">attention.self.key (Linear)</text>
+</g>
+<!-- attention.self.inp_1&#45;&gt;attention.self..112 -->
+<g id="edge6" class="edge">
+<title>attention.self.inp_1&#45;&gt;attention.self..112</title>
+<path fill="none" stroke="black" d="M256.44,-578.41C292.06,-568.11 340.24,-554.17 379.36,-542.85"/>
+<polygon fill="black" stroke="black" points="380.47,-546.18 389.1,-540.04 378.52,-539.45 380.47,-546.18"/>
+</g>
+<!-- attention.self..113 -->
+<g id="node9" class="node">
+<title>attention.self..113</title>
+<polygon fill="none" stroke="black" points="306.5,-468 91.5,-468 91.5,-432 306.5,-432 306.5,-468"/>
+<text text-anchor="middle" x="199" y="-446.3" font-family="Times,serif" font-size="14.00">attention.self.value (Linear)</text>
+</g>
+<!-- attention.self.inp_1&#45;&gt;attention.self..113 -->
+<g id="edge7" class="edge">
+<title>attention.self.inp_1&#45;&gt;attention.self..113</title>
+<path fill="none" stroke="black" d="M152.56,-578.71C133.17,-570.67 113.28,-558.4 102,-540 93.64,-526.36 93.84,-517.76 102,-504 109.91,-490.66 122.49,-480.47 136,-472.75"/>
+<polygon fill="black" stroke="black" points="137.76,-475.79 144.98,-468.04 134.5,-469.59 137.76,-475.79"/>
+</g>
+<!-- attention.self..114 -->
+<g id="node10" class="node">
+<title>attention.self..114</title>
+<polygon fill="none" stroke="black" points="571,-468 325,-468 325,-432 571,-432 571,-468"/>
+<text text-anchor="middle" x="448" y="-446.3" font-family="Times,serif" font-size="14.00">attention.self.dropout (Dropout)</text>
+</g>
+<!-- attention.self.inp_attention_mask&#45;&gt;attention.self..114 -->
+<g id="edge9" class="edge">
+<title>attention.self.inp_attention_mask&#45;&gt;attention.self..114</title>
+<path fill="none" stroke="black" d="M669.4,-505.38C627.45,-495.08 572.19,-481.5 527.38,-470.5"/>
+<polygon fill="black" stroke="black" points="527.94,-467.03 517.4,-468.04 526.27,-473.83 527.94,-467.03"/>
+</g>
+<!-- attention.self..111&#45;&gt;attention.self..114 -->
+<g id="edge10" class="edge">
+<title>attention.self..111&#45;&gt;attention.self..114</title>
+<path fill="none" stroke="black" d="M275.48,-503.97C307.76,-494.06 348.61,-481.51 382.46,-471.12"/>
+<polygon fill="black" stroke="black" points="383.74,-474.39 392.27,-468.11 381.68,-467.7 383.74,-474.39"/>
+</g>
+<!-- attention.self..112&#45;&gt;attention.self..114 -->
+<g id="edge8" class="edge">
+<title>attention.self..112&#45;&gt;attention.self..114</title>
+<path fill="none" stroke="black" d="M448,-503.7C448,-495.98 448,-486.71 448,-478.11"/>
+<polygon fill="black" stroke="black" points="451.5,-478.1 448,-468.1 444.5,-478.1 451.5,-478.1"/>
+</g>
+<!-- attention.self.out_0 -->
+<g id="node11" class="node">
+<title>attention.self.out_0</title>
+<ellipse fill="none" stroke="black" cx="200" cy="-378" rx="100.98" ry="18"/>
+<text text-anchor="middle" x="200" y="-374.3" font-family="Times,serif" font-size="14.00">attention.self.out_0</text>
+</g>
+<!-- attention.self..113&#45;&gt;attention.self.out_0 -->
+<g id="edge11" class="edge">
+<title>attention.self..113&#45;&gt;attention.self.out_0</title>
+<path fill="none" stroke="black" d="M199.25,-431.7C199.36,-423.98 199.49,-414.71 199.61,-406.11"/>
+<polygon fill="black" stroke="black" points="203.11,-406.15 199.76,-396.1 196.11,-406.05 203.11,-406.15"/>
+</g>
+<!-- attention.self..114&#45;&gt;attention.self.out_0 -->
+<g id="edge12" class="edge">
+<title>attention.self..114&#45;&gt;attention.self.out_0</title>
+<path fill="none" stroke="black" d="M387.65,-431.97C349.35,-421.16 299.93,-407.21 261.43,-396.34"/>
+<polygon fill="black" stroke="black" points="261.96,-392.85 251.38,-393.5 260.06,-399.59 261.96,-392.85"/>
+</g>
+<!-- attention.self.out_0&#45;&gt;attention..8 -->
+<g id="edge14" class="edge">
+<title>attention.self.out_0&#45;&gt;attention..8</title>
+<path fill="none" stroke="black" d="M200,-359.7C200,-351.98 200,-342.71 200,-334.11"/>
+<polygon fill="black" stroke="black" points="203.5,-334.1 200,-324.1 196.5,-334.1 203.5,-334.1"/>
+</g>
+<!-- attention.out_0 -->
+<g id="node13" class="node">
+<title>attention.out_0</title>
+<ellipse fill="none" stroke="black" cx="200" cy="-234" rx="81.49" ry="18"/>
+<text text-anchor="middle" x="200" y="-230.3" font-family="Times,serif" font-size="14.00">attention.out_0</text>
+</g>
+<!-- attention..8&#45;&gt;attention.out_0 -->
+<g id="edge15" class="edge">
+<title>attention..8&#45;&gt;attention.out_0</title>
+<path fill="none" stroke="black" d="M200,-287.7C200,-279.98 200,-270.71 200,-262.11"/>
+<polygon fill="black" stroke="black" points="203.5,-262.1 200,-252.1 196.5,-262.1 203.5,-262.1"/>
+</g>
+<!-- .10 -->
+<g id="node14" class="node">
+<title>.10</title>
+<polygon fill="none" stroke="black" points="248,-180 0,-180 0,-144 248,-144 248,-180"/>
+<text text-anchor="middle" x="124" y="-158.3" font-family="Times,serif" font-size="14.00">intermediate (BertIntermediate)</text>
+</g>
+<!-- attention.out_0&#45;&gt;.10 -->
+<g id="edge16" class="edge">
+<title>attention.out_0&#45;&gt;.10</title>
+<path fill="none" stroke="black" d="M181.99,-216.41C172.45,-207.63 160.57,-196.68 150.04,-186.99"/>
+<polygon fill="black" stroke="black" points="152.4,-184.4 142.68,-180.2 147.66,-189.55 152.4,-184.4"/>
+</g>
+<!-- .11 -->
+<g id="node15" class="node">
+<title>.11</title>
+<polygon fill="none" stroke="black" points="280.5,-108 119.5,-108 119.5,-72 280.5,-72 280.5,-108"/>
+<text text-anchor="middle" x="200" y="-86.3" font-family="Times,serif" font-size="14.00">output (BertOutput)</text>
+</g>
+<!-- attention.out_0&#45;&gt;.11 -->
+<g id="edge18" class="edge">
+<title>attention.out_0&#45;&gt;.11</title>
+<path fill="none" stroke="black" d="M225.39,-216.83C237.38,-207.65 250.47,-194.99 257,-180 263.39,-165.33 263.39,-158.67 257,-144 252.07,-132.68 243.39,-122.68 234.27,-114.51"/>
+<polygon fill="black" stroke="black" points="236.5,-111.81 226.58,-108.09 232.01,-117.19 236.5,-111.81"/>
+</g>
+<!-- .10&#45;&gt;.11 -->
+<g id="edge17" class="edge">
+<title>.10&#45;&gt;.11</title>
+<path fill="none" stroke="black" d="M142.79,-143.7C152.17,-135.05 163.68,-124.45 173.91,-115.03"/>
+<polygon fill="black" stroke="black" points="176.45,-117.45 181.43,-108.1 171.7,-112.3 176.45,-117.45"/>
+</g>
+<!-- out_0 -->
+<g id="node16" class="node">
+<title>out_0</title>
+<ellipse fill="none" stroke="black" cx="200" cy="-18" rx="36.29" ry="18"/>
+<text text-anchor="middle" x="200" y="-14.3" font-family="Times,serif" font-size="14.00">out_0</text>
+</g>
+<!-- .11&#45;&gt;out_0 -->
+<g id="edge19" class="edge">
+<title>.11&#45;&gt;out_0</title>
+<path fill="none" stroke="black" d="M200,-71.7C200,-63.98 200,-54.71 200,-46.11"/>
+<polygon fill="black" stroke="black" points="203.5,-46.1 200,-36.1 196.5,-46.1 203.5,-46.1"/>
+</g>
+</g>
+</svg>
diff --git a/images/bert-pytorch/bert_model.svg b/images/bert-pytorch/bert_model.svg
new file mode 100644
index 0000000..0a60e2c
--- /dev/null
+++ b/images/bert-pytorch/bert_model.svg
@@ -0,0 +1,325 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
+ "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
+<!-- Generated by graphviz version 2.43.0 (0)
+ -->
+<!-- Title: %3 Pages: 1 -->
+<svg width="526pt" height="1294pt"
+ viewBox="0.00 0.00 525.50 1294.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
+<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 1290)">
+<title>%3</title>
+<polygon fill="white" stroke="transparent" points="-4,4 -4,-1290 521.5,-1290 521.5,4 -4,4"/>
+<text text-anchor="middle" x="258.75" y="-1270.8" font-family="Times,serif" font-size="14.00">BertModel</text>
+<g id="clust1" class="cluster">
+<title>cluster_.3358</title>
+<polygon fill="none" stroke="black" points="34.5,-136 34.5,-1147 509.5,-1147 509.5,-136 34.5,-136"/>
+<text text-anchor="middle" x="272" y="-1131.8" font-family="Times,serif" font-size="14.00">encoder (BertEncoder)</text>
+</g>
+<!-- inp_input_ids -->
+<g id="node1" class="node">
+<title>inp_input_ids</title>
+<ellipse fill="none" stroke="black" cx="119.5" cy="-1245" rx="72.29" ry="18"/>
+<text text-anchor="middle" x="119.5" y="-1241.3" font-family="Times,serif" font-size="14.00">inp_input_ids</text>
+</g>
+<!-- .3357 -->
+<g id="node3" class="node">
+<title>.3357</title>
+<polygon fill="none" stroke="black" points="239,-1191 0,-1191 0,-1155 239,-1155 239,-1191"/>
+<text text-anchor="middle" x="119.5" y="-1169.3" font-family="Times,serif" font-size="14.00">embeddings (BertEmbeddings)</text>
+</g>
+<!-- inp_input_ids&#45;&gt;.3357 -->
+<g id="edge1" class="edge">
+<title>inp_input_ids&#45;&gt;.3357</title>
+<path fill="none" stroke="black" d="M119.5,-1226.7C119.5,-1218.98 119.5,-1209.71 119.5,-1201.11"/>
+<polygon fill="black" stroke="black" points="123,-1201.1 119.5,-1191.1 116,-1201.1 123,-1201.1"/>
+</g>
+<!-- inp_attention_mask.1 -->
+<g id="node2" class="node">
+<title>inp_attention_mask.1</title>
+<ellipse fill="none" stroke="black" cx="366.5" cy="-1173" rx="109.68" ry="18"/>
+<text text-anchor="middle" x="366.5" y="-1169.3" font-family="Times,serif" font-size="14.00">inp_attention_mask.1</text>
+</g>
+<!-- encoder.inp_attention_mask -->
+<g id="node5" class="node">
+<title>encoder.inp_attention_mask</title>
+<ellipse fill="none" stroke="black" cx="361.5" cy="-1098" rx="139.98" ry="18"/>
+<text text-anchor="middle" x="361.5" y="-1094.3" font-family="Times,serif" font-size="14.00">encoder.inp_attention_mask</text>
+</g>
+<!-- inp_attention_mask.1&#45;&gt;encoder.inp_attention_mask -->
+<g id="edge3" class="edge">
+<title>inp_attention_mask.1&#45;&gt;encoder.inp_attention_mask</title>
+<path fill="none" stroke="black" d="M365.32,-1154.7C364.74,-1146.25 364.03,-1135.87 363.37,-1126.37"/>
+<polygon fill="black" stroke="black" points="366.85,-1125.91 362.68,-1116.18 359.87,-1126.39 366.85,-1125.91"/>
+</g>
+<!-- encoder.inp_26 -->
+<g id="node4" class="node">
+<title>encoder.inp_26</title>
+<ellipse fill="none" stroke="black" cx="123.5" cy="-1098" rx="80.69" ry="18"/>
+<text text-anchor="middle" x="123.5" y="-1094.3" font-family="Times,serif" font-size="14.00">encoder.inp_26</text>
+</g>
+<!-- .3357&#45;&gt;encoder.inp_26 -->
+<g id="edge2" class="edge">
+<title>.3357&#45;&gt;encoder.inp_26</title>
+<path fill="none" stroke="black" d="M120.45,-1154.7C120.91,-1146.25 121.48,-1135.87 122,-1126.37"/>
+<polygon fill="black" stroke="black" points="125.51,-1126.35 122.56,-1116.18 118.52,-1125.97 125.51,-1126.35"/>
+</g>
+<!-- encoder..39 -->
+<g id="node6" class="node">
+<title>encoder..39</title>
+<polygon fill="none" stroke="black" points="294.5,-1044 82.5,-1044 82.5,-1008 294.5,-1008 294.5,-1044"/>
+<text text-anchor="middle" x="188.5" y="-1022.3" font-family="Times,serif" font-size="14.00">encoder.layer.0 (BertLayer)</text>
+</g>
+<!-- encoder.inp_26&#45;&gt;encoder..39 -->
+<g id="edge4" class="edge">
+<title>encoder.inp_26&#45;&gt;encoder..39</title>
+<path fill="none" stroke="black" d="M139.23,-1080.05C147.1,-1071.58 156.77,-1061.17 165.46,-1051.82"/>
+<polygon fill="black" stroke="black" points="168.22,-1053.98 172.46,-1044.28 163.09,-1049.22 168.22,-1053.98"/>
+</g>
+<!-- encoder.inp_attention_mask&#45;&gt;encoder..39 -->
+<g id="edge5" class="edge">
+<title>encoder.inp_attention_mask&#45;&gt;encoder..39</title>
+<path fill="none" stroke="black" d="M320.94,-1080.59C296.75,-1070.8 265.8,-1058.28 239.9,-1047.8"/>
+<polygon fill="black" stroke="black" points="241.2,-1044.55 230.61,-1044.04 238.57,-1051.04 241.2,-1044.55"/>
+</g>
+<!-- encoder..40 -->
+<g id="node7" class="node">
+<title>encoder..40</title>
+<polygon fill="none" stroke="black" points="334.5,-972 122.5,-972 122.5,-936 334.5,-936 334.5,-972"/>
+<text text-anchor="middle" x="228.5" y="-950.3" font-family="Times,serif" font-size="14.00">encoder.layer.1 (BertLayer)</text>
+</g>
+<!-- encoder.inp_attention_mask&#45;&gt;encoder..40 -->
+<g id="edge7" class="edge">
+<title>encoder.inp_attention_mask&#45;&gt;encoder..40</title>
+<path fill="none" stroke="black" d="M356.56,-1079.85C350.3,-1060.53 337.91,-1029.08 318.5,-1008 307,-995.52 291.89,-985.08 277.33,-976.86"/>
+<polygon fill="black" stroke="black" points="278.84,-973.7 268.38,-972.03 275.52,-979.86 278.84,-973.7"/>
+</g>
+<!-- encoder..41 -->
+<g id="node8" class="node">
+<title>encoder..41</title>
+<polygon fill="none" stroke="black" points="354.5,-900 142.5,-900 142.5,-864 354.5,-864 354.5,-900"/>
+<text text-anchor="middle" x="248.5" y="-878.3" font-family="Times,serif" font-size="14.00">encoder.layer.2 (BertLayer)</text>
+</g>
+<!-- encoder.inp_attention_mask&#45;&gt;encoder..41 -->
+<g id="edge9" class="edge">
+<title>encoder.inp_attention_mask&#45;&gt;encoder..41</title>
+<path fill="none" stroke="black" d="M364.36,-1079.79C368.62,-1048.34 373.24,-980.74 343.5,-936 334.75,-922.84 321.57,-912.62 307.8,-904.83"/>
+<polygon fill="black" stroke="black" points="309.18,-901.6 298.7,-900.05 305.92,-907.79 309.18,-901.6"/>
+</g>
+<!-- encoder..42 -->
+<g id="node9" class="node">
+<title>encoder..42</title>
+<polygon fill="none" stroke="black" points="374.5,-828 162.5,-828 162.5,-792 374.5,-792 374.5,-828"/>
+<text text-anchor="middle" x="268.5" y="-806.3" font-family="Times,serif" font-size="14.00">encoder.layer.3 (BertLayer)</text>
+</g>
+<!-- encoder.inp_attention_mask&#45;&gt;encoder..42 -->
+<g id="edge11" class="edge">
+<title>encoder.inp_attention_mask&#45;&gt;encoder..42</title>
+<path fill="none" stroke="black" d="M367.32,-1079.88C380.1,-1039.01 406.2,-933.91 363.5,-864 355.42,-850.78 342.77,-840.59 329.3,-832.84"/>
+<polygon fill="black" stroke="black" points="330.84,-829.7 320.37,-828.1 327.56,-835.88 330.84,-829.7"/>
+</g>
+<!-- encoder..43 -->
+<g id="node10" class="node">
+<title>encoder..43</title>
+<polygon fill="none" stroke="black" points="384.5,-756 172.5,-756 172.5,-720 384.5,-720 384.5,-756"/>
+<text text-anchor="middle" x="278.5" y="-734.3" font-family="Times,serif" font-size="14.00">encoder.layer.4 (BertLayer)</text>
+</g>
+<!-- encoder.inp_attention_mask&#45;&gt;encoder..43 -->
+<g id="edge13" class="edge">
+<title>encoder.inp_attention_mask&#45;&gt;encoder..43</title>
+<path fill="none" stroke="black" d="M370.57,-1079.82C375.51,-1069.67 381.26,-1056.42 384.5,-1044 394.53,-1005.53 393.5,-994.76 393.5,-955 393.5,-955 393.5,-955 393.5,-881 393.5,-841.2 406.22,-824.68 383.5,-792 374.14,-778.53 360.31,-768.3 345.74,-760.6"/>
+<polygon fill="black" stroke="black" points="347.13,-757.38 336.61,-756.12 344.05,-763.66 347.13,-757.38"/>
+</g>
+<!-- encoder..44 -->
+<g id="node11" class="node">
+<title>encoder..44</title>
+<polygon fill="none" stroke="black" points="384.5,-684 172.5,-684 172.5,-648 384.5,-648 384.5,-684"/>
+<text text-anchor="middle" x="278.5" y="-662.3" font-family="Times,serif" font-size="14.00">encoder.layer.5 (BertLayer)</text>
+</g>
+<!-- encoder.inp_attention_mask&#45;&gt;encoder..44 -->
+<g id="edge15" class="edge">
+<title>encoder.inp_attention_mask&#45;&gt;encoder..44</title>
+<path fill="none" stroke="black" d="M373.2,-1080C379.68,-1069.92 387.43,-1056.67 392.5,-1044 407.59,-1006.26 413.5,-995.64 413.5,-955 413.5,-955 413.5,-955 413.5,-809 413.5,-768.46 418.99,-751.52 393.5,-720 382.46,-706.35 367.13,-696.06 351.25,-688.37"/>
+<polygon fill="black" stroke="black" points="352.42,-685.05 341.86,-684.13 349.53,-691.43 352.42,-685.05"/>
+</g>
+<!-- encoder..45 -->
+<g id="node12" class="node">
+<title>encoder..45</title>
+<polygon fill="none" stroke="black" points="294.5,-612 82.5,-612 82.5,-576 294.5,-576 294.5,-612"/>
+<text text-anchor="middle" x="188.5" y="-590.3" font-family="Times,serif" font-size="14.00">encoder.layer.6 (BertLayer)</text>
+</g>
+<!-- encoder.inp_attention_mask&#45;&gt;encoder..45 -->
+<g id="edge16" class="edge">
+<title>encoder.inp_attention_mask&#45;&gt;encoder..45</title>
+<path fill="none" stroke="black" d="M345.42,-1079.8C336.69,-1069.86 326.15,-1056.82 318.5,-1044 309.62,-1029.12 316.96,-1018.92 303.5,-1008 236.76,-953.85 171.94,-1035.02 113.5,-972 86.43,-942.81 103.5,-922.8 103.5,-883 103.5,-883 103.5,-883 103.5,-737 103.5,-690.45 137.54,-645.95 162.54,-619.55"/>
+<polygon fill="black" stroke="black" points="165.16,-621.87 169.64,-612.27 160.15,-616.98 165.16,-621.87"/>
+</g>
+<!-- encoder..46 -->
+<g id="node13" class="node">
+<title>encoder..46</title>
+<polygon fill="none" stroke="black" points="284.5,-540 72.5,-540 72.5,-504 284.5,-504 284.5,-540"/>
+<text text-anchor="middle" x="178.5" y="-518.3" font-family="Times,serif" font-size="14.00">encoder.layer.7 (BertLayer)</text>
+</g>
+<!-- encoder.inp_attention_mask&#45;&gt;encoder..46 -->
+<g id="edge19" class="edge">
+<title>encoder.inp_attention_mask&#45;&gt;encoder..46</title>
+<path fill="none" stroke="black" d="M264.82,-1084.97C186.39,-1074.27 86.72,-1058.18 73.5,-1044 46.36,-1014.88 63.5,-994.8 63.5,-955 63.5,-955 63.5,-955 63.5,-665 63.5,-625.2 50.78,-608.68 73.5,-576 82.86,-562.53 96.69,-552.3 111.26,-544.6"/>
+<polygon fill="black" stroke="black" points="112.95,-547.66 120.39,-540.12 109.87,-541.38 112.95,-547.66"/>
+</g>
+<!-- encoder..47 -->
+<g id="node14" class="node">
+<title>encoder..47</title>
+<polygon fill="none" stroke="black" points="284.5,-468 72.5,-468 72.5,-432 284.5,-432 284.5,-468"/>
+<text text-anchor="middle" x="178.5" y="-446.3" font-family="Times,serif" font-size="14.00">encoder.layer.8 (BertLayer)</text>
+</g>
+<!-- encoder.inp_attention_mask&#45;&gt;encoder..47 -->
+<g id="edge21" class="edge">
+<title>encoder.inp_attention_mask&#45;&gt;encoder..47</title>
+<path fill="none" stroke="black" d="M260.15,-1085.54C176.34,-1075.09 68.82,-1059.02 54.5,-1044 27,-1015.15 43.5,-994.86 43.5,-955 43.5,-955 43.5,-955 43.5,-593 43.5,-552.46 38.01,-535.52 63.5,-504 74.54,-490.35 89.87,-480.06 105.75,-472.37"/>
+<polygon fill="black" stroke="black" points="107.47,-475.43 115.14,-468.13 104.58,-469.05 107.47,-475.43"/>
+</g>
+<!-- encoder..48 -->
+<g id="node15" class="node">
+<title>encoder..48</title>
+<polygon fill="none" stroke="black" points="434.5,-396 222.5,-396 222.5,-360 434.5,-360 434.5,-396"/>
+<text text-anchor="middle" x="328.5" y="-374.3" font-family="Times,serif" font-size="14.00">encoder.layer.9 (BertLayer)</text>
+</g>
+<!-- encoder.inp_attention_mask&#45;&gt;encoder..48 -->
+<g id="edge23" class="edge">
+<title>encoder.inp_attention_mask&#45;&gt;encoder..48</title>
+<path fill="none" stroke="black" d="M377.53,-1079.94C386.15,-1070.04 396.44,-1057 403.5,-1044 423.43,-1007.32 433.5,-996.74 433.5,-955 433.5,-955 433.5,-955 433.5,-521 433.5,-471.46 391.68,-427.9 360.82,-402.49"/>
+<polygon fill="black" stroke="black" points="362.85,-399.63 352.85,-396.13 358.48,-405.1 362.85,-399.63"/>
+</g>
+<!-- encoder..49 -->
+<g id="node16" class="node">
+<title>encoder..49</title>
+<polygon fill="none" stroke="black" points="454,-324 233,-324 233,-288 454,-288 454,-324"/>
+<text text-anchor="middle" x="343.5" y="-302.3" font-family="Times,serif" font-size="14.00">encoder.layer.10 (BertLayer)</text>
+</g>
+<!-- encoder.inp_attention_mask&#45;&gt;encoder..49 -->
+<g id="edge24" class="edge">
+<title>encoder.inp_attention_mask&#45;&gt;encoder..49</title>
+<path fill="none" stroke="black" d="M383.35,-1080.2C394.72,-1070.58 408.15,-1057.71 417.5,-1044 441.54,-1008.75 453.5,-997.67 453.5,-955 453.5,-955 453.5,-955 453.5,-449 453.5,-409.2 465.9,-392.9 443.5,-360 434.39,-346.62 420.78,-336.36 406.52,-328.6"/>
+<polygon fill="black" stroke="black" points="408.09,-325.47 397.59,-324.08 404.94,-331.72 408.09,-325.47"/>
+</g>
+<!-- encoder..50 -->
+<g id="node17" class="node">
+<title>encoder..50</title>
+<polygon fill="none" stroke="black" points="478,-252 257,-252 257,-216 478,-216 478,-252"/>
+<text text-anchor="middle" x="367.5" y="-230.3" font-family="Times,serif" font-size="14.00">encoder.layer.11 (BertLayer)</text>
+</g>
+<!-- encoder.inp_attention_mask&#45;&gt;encoder..50 -->
+<g id="edge27" class="edge">
+<title>encoder.inp_attention_mask&#45;&gt;encoder..50</title>
+<path fill="none" stroke="black" d="M390.96,-1080.29C405.25,-1071.04 421.8,-1058.49 433.5,-1044 460.75,-1010.26 473.5,-998.37 473.5,-955 473.5,-955 473.5,-955 473.5,-377 473.5,-337.14 484.76,-321.06 462.5,-288 453.68,-274.89 440.47,-264.69 426.69,-256.89"/>
+<polygon fill="black" stroke="black" points="428.08,-253.67 417.6,-252.11 424.82,-259.86 428.08,-253.67"/>
+</g>
+<!-- encoder..39&#45;&gt;encoder..40 -->
+<g id="edge6" class="edge">
+<title>encoder..39&#45;&gt;encoder..40</title>
+<path fill="none" stroke="black" d="M198.39,-1007.7C202.99,-999.64 208.56,-989.89 213.65,-980.98"/>
+<polygon fill="black" stroke="black" points="216.8,-982.52 218.73,-972.1 210.73,-979.05 216.8,-982.52"/>
+</g>
+<!-- encoder..40&#45;&gt;encoder..41 -->
+<g id="edge8" class="edge">
+<title>encoder..40&#45;&gt;encoder..41</title>
+<path fill="none" stroke="black" d="M233.44,-935.7C235.67,-927.9 238.35,-918.51 240.83,-909.83"/>
+<polygon fill="black" stroke="black" points="244.23,-910.68 243.61,-900.1 237.5,-908.76 244.23,-910.68"/>
+</g>
+<!-- encoder..41&#45;&gt;encoder..42 -->
+<g id="edge10" class="edge">
+<title>encoder..41&#45;&gt;encoder..42</title>
+<path fill="none" stroke="black" d="M253.44,-863.7C255.67,-855.9 258.35,-846.51 260.83,-837.83"/>
+<polygon fill="black" stroke="black" points="264.23,-838.68 263.61,-828.1 257.5,-836.76 264.23,-838.68"/>
+</g>
+<!-- encoder..42&#45;&gt;encoder..43 -->
+<g id="edge12" class="edge">
+<title>encoder..42&#45;&gt;encoder..43</title>
+<path fill="none" stroke="black" d="M270.97,-791.7C272.07,-783.98 273.4,-774.71 274.63,-766.11"/>
+<polygon fill="black" stroke="black" points="278.11,-766.5 276.06,-756.1 271.18,-765.51 278.11,-766.5"/>
+</g>
+<!-- encoder..43&#45;&gt;encoder..44 -->
+<g id="edge14" class="edge">
+<title>encoder..43&#45;&gt;encoder..44</title>
+<path fill="none" stroke="black" d="M278.5,-719.7C278.5,-711.98 278.5,-702.71 278.5,-694.11"/>
+<polygon fill="black" stroke="black" points="282,-694.1 278.5,-684.1 275,-694.1 282,-694.1"/>
+</g>
+<!-- encoder..44&#45;&gt;encoder..45 -->
+<g id="edge17" class="edge">
+<title>encoder..44&#45;&gt;encoder..45</title>
+<path fill="none" stroke="black" d="M256.25,-647.7C244.92,-638.88 230.97,-628.03 218.68,-618.47"/>
+<polygon fill="black" stroke="black" points="220.53,-615.48 210.49,-612.1 216.24,-621.01 220.53,-615.48"/>
+</g>
+<!-- encoder..45&#45;&gt;encoder..46 -->
+<g id="edge18" class="edge">
+<title>encoder..45&#45;&gt;encoder..46</title>
+<path fill="none" stroke="black" d="M186.03,-575.7C184.93,-567.98 183.6,-558.71 182.37,-550.11"/>
+<polygon fill="black" stroke="black" points="185.82,-549.51 180.94,-540.1 178.89,-550.5 185.82,-549.51"/>
+</g>
+<!-- encoder..46&#45;&gt;encoder..47 -->
+<g id="edge20" class="edge">
+<title>encoder..46&#45;&gt;encoder..47</title>
+<path fill="none" stroke="black" d="M178.5,-503.7C178.5,-495.98 178.5,-486.71 178.5,-478.11"/>
+<polygon fill="black" stroke="black" points="182,-478.1 178.5,-468.1 175,-478.1 182,-478.1"/>
+</g>
+<!-- encoder..47&#45;&gt;encoder..48 -->
+<g id="edge22" class="edge">
+<title>encoder..47&#45;&gt;encoder..48</title>
+<path fill="none" stroke="black" d="M215.19,-431.88C235.52,-422.39 260.97,-410.51 282.6,-400.42"/>
+<polygon fill="black" stroke="black" points="284.34,-403.47 291.93,-396.07 281.38,-397.13 284.34,-403.47"/>
+</g>
+<!-- encoder..48&#45;&gt;encoder..49 -->
+<g id="edge25" class="edge">
+<title>encoder..48&#45;&gt;encoder..49</title>
+<path fill="none" stroke="black" d="M332.21,-359.7C333.86,-351.98 335.85,-342.71 337.69,-334.11"/>
+<polygon fill="black" stroke="black" points="341.16,-334.62 339.83,-324.1 334.32,-333.15 341.16,-334.62"/>
+</g>
+<!-- encoder..49&#45;&gt;encoder..50 -->
+<g id="edge26" class="edge">
+<title>encoder..49&#45;&gt;encoder..50</title>
+<path fill="none" stroke="black" d="M349.43,-287.7C352.11,-279.9 355.33,-270.51 358.3,-261.83"/>
+<polygon fill="black" stroke="black" points="361.7,-262.7 361.64,-252.1 355.08,-260.43 361.7,-262.7"/>
+</g>
+<!-- encoder.out_0 -->
+<g id="node18" class="node">
+<title>encoder.out_0</title>
+<ellipse fill="none" stroke="black" cx="367.5" cy="-162" rx="75.29" ry="18"/>
+<text text-anchor="middle" x="367.5" y="-158.3" font-family="Times,serif" font-size="14.00">encoder.out_0</text>
+</g>
+<!-- encoder..50&#45;&gt;encoder.out_0 -->
+<g id="edge28" class="edge">
+<title>encoder..50&#45;&gt;encoder.out_0</title>
+<path fill="none" stroke="black" d="M367.5,-215.7C367.5,-207.98 367.5,-198.71 367.5,-190.11"/>
+<polygon fill="black" stroke="black" points="371,-190.1 367.5,-180.1 364,-190.1 371,-190.1"/>
+</g>
+<!-- .3359 -->
+<g id="node19" class="node">
+<title>.3359</title>
+<polygon fill="none" stroke="black" points="391,-108 238,-108 238,-72 391,-72 391,-108"/>
+<text text-anchor="middle" x="314.5" y="-86.3" font-family="Times,serif" font-size="14.00">pooler (BertPooler)</text>
+</g>
+<!-- encoder.out_0&#45;&gt;.3359 -->
+<g id="edge29" class="edge">
+<title>encoder.out_0&#45;&gt;.3359</title>
+<path fill="none" stroke="black" d="M354.67,-144.05C348.38,-135.75 340.68,-125.58 333.71,-116.38"/>
+<polygon fill="black" stroke="black" points="336.41,-114.14 327.58,-108.28 330.83,-118.36 336.41,-114.14"/>
+</g>
+<!-- out_0 -->
+<g id="node20" class="node">
+<title>out_0</title>
+<ellipse fill="none" stroke="black" cx="366.5" cy="-18" rx="36.29" ry="18"/>
+<text text-anchor="middle" x="366.5" y="-14.3" font-family="Times,serif" font-size="14.00">out_0</text>
+</g>
+<!-- encoder.out_0&#45;&gt;out_0 -->
+<g id="edge31" class="edge">
+<title>encoder.out_0&#45;&gt;out_0</title>
+<path fill="none" stroke="black" d="M381.48,-144.04C388.69,-134.18 396.75,-121.13 400.5,-108 404.9,-92.62 405,-87.35 400.5,-72 397.53,-61.88 391.97,-51.86 386.2,-43.3"/>
+<polygon fill="black" stroke="black" points="388.88,-41.04 380.2,-34.97 383.2,-45.13 388.88,-41.04"/>
+</g>
+<!-- .3359&#45;&gt;out_0 -->
+<g id="edge30" class="edge">
+<title>.3359&#45;&gt;out_0</title>
+<path fill="none" stroke="black" d="M327.35,-71.7C333.71,-63.14 341.49,-52.67 348.43,-43.33"/>
+<polygon fill="black" stroke="black" points="351.35,-45.26 354.51,-35.14 345.73,-41.08 351.35,-45.26"/>
+</g>
+</g>
+</svg>
diff --git a/images/bert-pytorch/pytorch-tvm-training_20_0.svg b/images/bert-pytorch/pytorch-tvm-training_20_0.svg
new file mode 100644
index 0000000..4521fb8
--- /dev/null
+++ b/images/bert-pytorch/pytorch-tvm-training_20_0.svg
@@ -0,0 +1,1237 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
+ "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
+<!-- Generated by graphviz version 2.43.0 (0)
+ -->
+<!-- Title: %3 Pages: 1 -->
+<svg width="3671pt" height="3716pt"
+ viewBox="0.00 0.00 3671.03 3716.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
+<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 3712)">
+<title>%3</title>
+<polygon fill="white" stroke="transparent" points="-4,4 -4,-3712 3667.03,-3712 3667.03,4 -4,4"/>
+<!-- 0 -->
+<g id="node1" class="node">
+<title>0</title>
+<ellipse fill="none" stroke="black" cx="2012.5" cy="-3546" rx="170.87" ry="18"/>
+<text text-anchor="middle" x="2012.5" y="-3542.3" font-family="Times,serif" font-size="14.00">input: Tensor[(1, 14, 768), float32]</text>
+</g>
+<!-- 29 -->
+<g id="node19" class="node">
+<title>29</title>
+<polygon fill="none" stroke="black" points="2038.5,-3492 1590.5,-3492 1590.5,-3456 2038.5,-3456 2038.5,-3492"/>
+<text text-anchor="middle" x="1814.5" y="-3470.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 &#160;14 768]| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 0&#45;&gt;29 -->
+<g id="edge1" class="edge">
+<title>0&#45;&gt;29</title>
+<path fill="none" stroke="black" d="M1966.08,-3528.59C1938.02,-3518.67 1902.03,-3505.95 1872.16,-3495.39"/>
+<polygon fill="black" stroke="black" points="1873.29,-3492.07 1862.7,-3492.04 1870.96,-3498.67 1873.29,-3492.07"/>
+</g>
+<!-- 44 -->
+<g id="node29" class="node">
+<title>44</title>
+<polygon fill="none" stroke="black" points="2504.5,-3492 2056.5,-3492 2056.5,-3456 2504.5,-3456 2504.5,-3492"/>
+<text text-anchor="middle" x="2280.5" y="-3470.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 &#160;14 768]| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 0&#45;&gt;44 -->
+<g id="edge13" class="edge">
+<title>0&#45;&gt;44</title>
+<path fill="none" stroke="black" d="M2073.31,-3529.12C2112.49,-3518.88 2163.7,-3505.51 2205.41,-3494.61"/>
+<polygon fill="black" stroke="black" points="2206.45,-3497.96 2215.25,-3492.04 2204.69,-3491.19 2206.45,-3497.96"/>
+</g>
+<!-- 72 -->
+<g id="node49" class="node">
+<title>72</title>
+<polygon fill="none" stroke="black" points="1572.5,-3492 1124.5,-3492 1124.5,-3456 1572.5,-3456 1572.5,-3492"/>
+<text text-anchor="middle" x="1348.5" y="-3470.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 &#160;14 768]| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 0&#45;&gt;72 -->
+<g id="edge37" class="edge">
+<title>0&#45;&gt;72</title>
+<path fill="none" stroke="black" d="M1896.46,-3532.77C1792.43,-3521.8 1638.42,-3505.56 1520.15,-3493.1"/>
+<polygon fill="black" stroke="black" points="1520.21,-3489.58 1509.9,-3492.01 1519.48,-3496.54 1520.21,-3489.58"/>
+</g>
+<!-- 106 -->
+<g id="node74" class="node">
+<title>106</title>
+<polygon fill="none" stroke="black" points="2856,-1476 2785,-1476 2785,-1440 2856,-1440 2856,-1476"/>
+<text text-anchor="middle" x="2820.5" y="-1454.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 0&#45;&gt;106 -->
+<g id="edge68" class="edge">
+<title>0&#45;&gt;106</title>
+<path fill="none" stroke="black" d="M2161.62,-3537.11C2367.88,-3525.77 2717.95,-3504.82 2743.5,-3492 2790.25,-3468.55 2820.5,-3455.3 2820.5,-3403 2820.5,-3403 2820.5,-3403 2820.5,-1601 2820.5,-1561 2820.5,-1514.65 2820.5,-1486.08"/>
+<polygon fill="black" stroke="black" points="2824,-1486.05 2820.5,-1476.05 2817,-1486.05 2824,-1486.05"/>
+</g>
+<!-- 1 -->
+<g id="node2" class="node">
+<title>1</title>
+<ellipse fill="none" stroke="black" cx="1743.5" cy="-2682" rx="217.96" ry="18"/>
+<text text-anchor="middle" x="1743.5" y="-2678.3" font-family="Times,serif" font-size="14.00">attention_mask: Tensor[(1, 1, 1, 14), float32]</text>
+</g>
+<!-- 65 -->
+<g id="node44" class="node">
+<title>65</title>
+<polygon fill="none" stroke="black" points="2024,-2628 1953,-2628 1953,-2592 2024,-2592 2024,-2628"/>
+<text text-anchor="middle" x="1988.5" y="-2606.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 1&#45;&gt;65 -->
+<g id="edge32" class="edge">
+<title>1&#45;&gt;65</title>
+<path fill="none" stroke="black" d="M1800.94,-2664.59C1844.23,-2652.22 1902.79,-2635.49 1942.89,-2624.03"/>
+<polygon fill="black" stroke="black" points="1944.01,-2627.35 1952.67,-2621.24 1942.09,-2620.62 1944.01,-2627.35"/>
+</g>
+<!-- 2 -->
+<g id="node3" class="node">
+<title>2</title>
+<ellipse fill="none" stroke="black" cx="1000.5" cy="-3690" rx="265.65" ry="18"/>
+<text text-anchor="middle" x="1000.5" y="-3686.3" font-family="Times,serif" font-size="14.00">attention.self.query.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 30 -->
+<g id="node20" class="node">
+<title>30</title>
+<polygon fill="none" stroke="black" points="1097,-3636 904,-3636 904,-3600 1097,-3600 1097,-3636"/>
+<text text-anchor="middle" x="1000.5" y="-3614.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 2&#45;&gt;30 -->
+<g id="edge2" class="edge">
+<title>2&#45;&gt;30</title>
+<path fill="none" stroke="black" d="M1000.5,-3671.7C1000.5,-3663.98 1000.5,-3654.71 1000.5,-3646.11"/>
+<polygon fill="black" stroke="black" points="1004,-3646.1 1000.5,-3636.1 997,-3646.1 1004,-3646.1"/>
+</g>
+<!-- 3 -->
+<g id="node4" class="node">
+<title>3</title>
+<ellipse fill="none" stroke="black" cx="1637.5" cy="-3330" rx="232.86" ry="18"/>
+<text text-anchor="middle" x="1637.5" y="-3326.3" font-family="Times,serif" font-size="14.00">attention.self.query.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 37 -->
+<g id="node25" class="node">
+<title>37</title>
+<polygon fill="none" stroke="black" points="1673,-3276 1602,-3276 1602,-3240 1673,-3240 1673,-3276"/>
+<text text-anchor="middle" x="1637.5" y="-3254.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 3&#45;&gt;37 -->
+<g id="edge9" class="edge">
+<title>3&#45;&gt;37</title>
+<path fill="none" stroke="black" d="M1637.5,-3311.7C1637.5,-3303.98 1637.5,-3294.71 1637.5,-3286.11"/>
+<polygon fill="black" stroke="black" points="1641,-3286.1 1637.5,-3276.1 1634,-3286.1 1641,-3286.1"/>
+</g>
+<!-- 4 -->
+<g id="node5" class="node">
+<title>4</title>
+<ellipse fill="none" stroke="black" cx="2628.5" cy="-3690" rx="254.55" ry="18"/>
+<text text-anchor="middle" x="2628.5" y="-3686.3" font-family="Times,serif" font-size="14.00">attention.self.key.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 45 -->
+<g id="node30" class="node">
+<title>45</title>
+<polygon fill="none" stroke="black" points="2725,-3636 2532,-3636 2532,-3600 2725,-3600 2725,-3636"/>
+<text text-anchor="middle" x="2628.5" y="-3614.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 4&#45;&gt;45 -->
+<g id="edge14" class="edge">
+<title>4&#45;&gt;45</title>
+<path fill="none" stroke="black" d="M2628.5,-3671.7C2628.5,-3663.98 2628.5,-3654.71 2628.5,-3646.11"/>
+<polygon fill="black" stroke="black" points="2632,-3646.1 2628.5,-3636.1 2625,-3646.1 2632,-3646.1"/>
+</g>
+<!-- 5 -->
+<g id="node6" class="node">
+<title>5</title>
+<ellipse fill="none" stroke="black" cx="2109.5" cy="-3330" rx="221.76" ry="18"/>
+<text text-anchor="middle" x="2109.5" y="-3326.3" font-family="Times,serif" font-size="14.00">attention.self.key.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 52 -->
+<g id="node35" class="node">
+<title>52</title>
+<polygon fill="none" stroke="black" points="2145,-3276 2074,-3276 2074,-3240 2145,-3240 2145,-3276"/>
+<text text-anchor="middle" x="2109.5" y="-3254.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 5&#45;&gt;52 -->
+<g id="edge21" class="edge">
+<title>5&#45;&gt;52</title>
+<path fill="none" stroke="black" d="M2109.5,-3311.7C2109.5,-3303.98 2109.5,-3294.71 2109.5,-3286.11"/>
+<polygon fill="black" stroke="black" points="2113,-3286.1 2109.5,-3276.1 2106,-3286.1 2113,-3286.1"/>
+</g>
+<!-- 6 -->
+<g id="node7" class="node">
+<title>6</title>
+<ellipse fill="none" stroke="black" cx="336.5" cy="-3690" rx="265.35" ry="18"/>
+<text text-anchor="middle" x="336.5" y="-3686.3" font-family="Times,serif" font-size="14.00">attention.self.value.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 73 -->
+<g id="node50" class="node">
+<title>73</title>
+<polygon fill="none" stroke="black" points="433,-3636 240,-3636 240,-3600 433,-3600 433,-3636"/>
+<text text-anchor="middle" x="336.5" y="-3614.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 6&#45;&gt;73 -->
+<g id="edge38" class="edge">
+<title>6&#45;&gt;73</title>
+<path fill="none" stroke="black" d="M336.5,-3671.7C336.5,-3663.98 336.5,-3654.71 336.5,-3646.11"/>
+<polygon fill="black" stroke="black" points="340,-3646.1 336.5,-3636.1 333,-3646.1 340,-3646.1"/>
+</g>
+<!-- 7 -->
+<g id="node8" class="node">
+<title>7</title>
+<ellipse fill="none" stroke="black" cx="693.5" cy="-3330" rx="232.06" ry="18"/>
+<text text-anchor="middle" x="693.5" y="-3326.3" font-family="Times,serif" font-size="14.00">attention.self.value.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 80 -->
+<g id="node55" class="node">
+<title>80</title>
+<polygon fill="none" stroke="black" points="729,-3276 658,-3276 658,-3240 729,-3240 729,-3276"/>
+<text text-anchor="middle" x="693.5" y="-3254.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 7&#45;&gt;80 -->
+<g id="edge45" class="edge">
+<title>7&#45;&gt;80</title>
+<path fill="none" stroke="black" d="M693.5,-3311.7C693.5,-3303.98 693.5,-3294.71 693.5,-3286.11"/>
+<polygon fill="black" stroke="black" points="697,-3286.1 693.5,-3276.1 690,-3286.1 697,-3286.1"/>
+</g>
+<!-- 8 -->
+<g id="node9" class="node">
+<title>8</title>
+<ellipse fill="none" stroke="black" cx="1578.5" cy="-2106" rx="282.15" ry="18"/>
+<text text-anchor="middle" x="1578.5" y="-2102.3" font-family="Times,serif" font-size="14.00">attention.output.dense.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 96 -->
+<g id="node66" class="node">
+<title>96</title>
+<polygon fill="none" stroke="black" points="1675,-2052 1482,-2052 1482,-2016 1675,-2016 1675,-2052"/>
+<text text-anchor="middle" x="1578.5" y="-2030.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 8&#45;&gt;96 -->
+<g id="edge57" class="edge">
+<title>8&#45;&gt;96</title>
+<path fill="none" stroke="black" d="M1578.5,-2087.7C1578.5,-2079.98 1578.5,-2070.71 1578.5,-2062.11"/>
+<polygon fill="black" stroke="black" points="1582,-2062.1 1578.5,-2052.1 1575,-2062.1 1582,-2062.1"/>
+</g>
+<!-- 9 -->
+<g id="node10" class="node">
+<title>9</title>
+<ellipse fill="none" stroke="black" cx="2542.5" cy="-1746" rx="248.86" ry="18"/>
+<text text-anchor="middle" x="2542.5" y="-1742.3" font-family="Times,serif" font-size="14.00">attention.output.dense.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 103 -->
+<g id="node71" class="node">
+<title>103</title>
+<polygon fill="none" stroke="black" points="2578,-1692 2507,-1692 2507,-1656 2578,-1656 2578,-1692"/>
+<text text-anchor="middle" x="2542.5" y="-1670.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 9&#45;&gt;103 -->
+<g id="edge64" class="edge">
+<title>9&#45;&gt;103</title>
+<path fill="none" stroke="black" d="M2542.5,-1727.7C2542.5,-1719.98 2542.5,-1710.71 2542.5,-1702.11"/>
+<polygon fill="black" stroke="black" points="2546,-1702.1 2542.5,-1692.1 2539,-1702.1 2546,-1702.1"/>
+</g>
+<!-- 10 -->
+<g id="node11" class="node">
+<title>10</title>
+<ellipse fill="none" stroke="black" cx="1913.5" cy="-1458" rx="286.75" ry="18"/>
+<text text-anchor="middle" x="1913.5" y="-1454.3" font-family="Times,serif" font-size="14.00">attention.output.LayerNorm.weight: Tensor[(768,), float32]</text>
+</g>
+<!-- 107 -->
+<g id="node75" class="node">
+<title>107</title>
+<polygon fill="none" stroke="black" points="2739.5,-1404 2261.5,-1404 2261.5,-1368 2739.5,-1368 2739.5,-1404"/>
+<text text-anchor="middle" x="2500.5" y="-1382.3" font-family="Times,serif" font-size="14.00">nn.layer_norm(·, ·, ·| axis=&#45;1, epsilon=1e&#45;12, center=1, scale=1)</text>
+</g>
+<!-- 10&#45;&gt;107 -->
+<g id="edge70" class="edge">
+<title>10&#45;&gt;107</title>
+<path fill="none" stroke="black" d="M2040.51,-1441.85C2130.7,-1431.1 2252.23,-1416.61 2347.72,-1405.22"/>
+<polygon fill="black" stroke="black" points="2348.24,-1408.68 2357.76,-1404.02 2347.41,-1401.73 2348.24,-1408.68"/>
+</g>
+<!-- 11 -->
+<g id="node12" class="node">
+<title>11</title>
+<ellipse fill="none" stroke="black" cx="2492.5" cy="-1458" rx="274.05" ry="18"/>
+<text text-anchor="middle" x="2492.5" y="-1454.3" font-family="Times,serif" font-size="14.00">attention.output.LayerNorm.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 11&#45;&gt;107 -->
+<g id="edge71" class="edge">
+<title>11&#45;&gt;107</title>
+<path fill="none" stroke="black" d="M2494.48,-1439.7C2495.36,-1431.98 2496.42,-1422.71 2497.4,-1414.11"/>
+<polygon fill="black" stroke="black" points="2500.89,-1414.44 2498.55,-1404.1 2493.93,-1413.64 2500.89,-1414.44"/>
+</g>
+<!-- 12 -->
+<g id="node13" class="node">
+<title>12</title>
+<ellipse fill="none" stroke="black" cx="3120.5" cy="-1530" rx="271.85" ry="18"/>
+<text text-anchor="middle" x="3120.5" y="-1526.3" font-family="Times,serif" font-size="14.00">intermediate.dense.weight: Tensor[(3072, 768), float32]</text>
+</g>
+<!-- 110 -->
+<g id="node77" class="node">
+<title>110</title>
+<polygon fill="none" stroke="black" points="3214,-1476 3021,-1476 3021,-1440 3214,-1440 3214,-1476"/>
+<text text-anchor="middle" x="3117.5" y="-1454.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 12&#45;&gt;110 -->
+<g id="edge73" class="edge">
+<title>12&#45;&gt;110</title>
+<path fill="none" stroke="black" d="M3119.76,-1511.7C3119.43,-1503.98 3119.03,-1494.71 3118.66,-1486.11"/>
+<polygon fill="black" stroke="black" points="3122.16,-1485.95 3118.23,-1476.1 3115.16,-1486.25 3122.16,-1485.95"/>
+</g>
+<!-- 13 -->
+<g id="node14" class="node">
+<title>13</title>
+<ellipse fill="none" stroke="black" cx="3424.5" cy="-1170" rx="238.56" ry="18"/>
+<text text-anchor="middle" x="3424.5" y="-1166.3" font-family="Times,serif" font-size="14.00">intermediate.dense.bias: Tensor[(3072,), float32]</text>
+</g>
+<!-- 117 -->
+<g id="node82" class="node">
+<title>117</title>
+<polygon fill="none" stroke="black" points="3179,-1116 3108,-1116 3108,-1080 3179,-1080 3179,-1116"/>
+<text text-anchor="middle" x="3143.5" y="-1094.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 13&#45;&gt;117 -->
+<g id="edge80" class="edge">
+<title>13&#45;&gt;117</title>
+<path fill="none" stroke="black" d="M3358.97,-1152.68C3306.7,-1139.65 3234.84,-1121.75 3188.77,-1110.28"/>
+<polygon fill="black" stroke="black" points="3189.62,-1106.88 3179.07,-1107.86 3187.93,-1113.67 3189.62,-1106.88"/>
+</g>
+<!-- 14 -->
+<g id="node15" class="node">
+<title>14</title>
+<ellipse fill="none" stroke="black" cx="2770.5" cy="-882" rx="242.36" ry="18"/>
+<text text-anchor="middle" x="2770.5" y="-878.3" font-family="Times,serif" font-size="14.00">output.dense.weight: Tensor[(768, 3072), float32]</text>
+</g>
+<!-- 129 -->
+<g id="node89" class="node">
+<title>129</title>
+<polygon fill="none" stroke="black" points="2862,-828 2669,-828 2669,-792 2862,-792 2862,-828"/>
+<text text-anchor="middle" x="2765.5" y="-806.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 14&#45;&gt;129 -->
+<g id="edge88" class="edge">
+<title>14&#45;&gt;129</title>
+<path fill="none" stroke="black" d="M2769.26,-863.7C2768.71,-855.98 2768.05,-846.71 2767.44,-838.11"/>
+<polygon fill="black" stroke="black" points="2770.93,-837.83 2766.72,-828.1 2763.94,-838.33 2770.93,-837.83"/>
+</g>
+<!-- 15 -->
+<g id="node16" class="node">
+<title>15</title>
+<ellipse fill="none" stroke="black" cx="3203.5" cy="-522" rx="203.36" ry="18"/>
+<text text-anchor="middle" x="3203.5" y="-518.3" font-family="Times,serif" font-size="14.00">output.dense.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 136 -->
+<g id="node94" class="node">
+<title>136</title>
+<polygon fill="none" stroke="black" points="2796,-468 2725,-468 2725,-432 2796,-432 2796,-468"/>
+<text text-anchor="middle" x="2760.5" y="-446.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 15&#45;&gt;136 -->
+<g id="edge95" class="edge">
+<title>15&#45;&gt;136</title>
+<path fill="none" stroke="black" d="M3108.74,-506.03C3015.47,-491.29 2877.01,-469.41 2806.05,-458.2"/>
+<polygon fill="black" stroke="black" points="2806.51,-454.73 2796.09,-456.62 2805.42,-461.64 2806.51,-454.73"/>
+</g>
+<!-- 16 -->
+<g id="node17" class="node">
+<title>16</title>
+<ellipse fill="none" stroke="black" cx="1788.5" cy="-234" rx="241.26" ry="18"/>
+<text text-anchor="middle" x="1788.5" y="-230.3" font-family="Times,serif" font-size="14.00">output.LayerNorm.weight: Tensor[(768,), float32]</text>
+</g>
+<!-- 140 -->
+<g id="node98" class="node">
+<title>140</title>
+<polygon fill="none" stroke="black" points="2515.5,-180 2037.5,-180 2037.5,-144 2515.5,-144 2515.5,-180"/>
+<text text-anchor="middle" x="2276.5" y="-158.3" font-family="Times,serif" font-size="14.00">nn.layer_norm(·, ·, ·| axis=&#45;1, epsilon=1e&#45;12, center=1, scale=1)</text>
+</g>
+<!-- 16&#45;&gt;140 -->
+<g id="edge101" class="edge">
+<title>16&#45;&gt;140</title>
+<path fill="none" stroke="black" d="M1894.39,-217.81C1968.85,-207.13 2068.91,-192.78 2147.95,-181.44"/>
+<polygon fill="black" stroke="black" points="2148.57,-184.89 2157.97,-180 2147.58,-177.96 2148.57,-184.89"/>
+</g>
+<!-- 17 -->
+<g id="node18" class="node">
+<title>17</title>
+<ellipse fill="none" stroke="black" cx="2276.5" cy="-234" rx="228.56" ry="18"/>
+<text text-anchor="middle" x="2276.5" y="-230.3" font-family="Times,serif" font-size="14.00">output.LayerNorm.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 17&#45;&gt;140 -->
+<g id="edge102" class="edge">
+<title>17&#45;&gt;140</title>
+<path fill="none" stroke="black" d="M2276.5,-215.7C2276.5,-207.98 2276.5,-198.71 2276.5,-190.11"/>
+<polygon fill="black" stroke="black" points="2280,-190.1 2276.5,-180.1 2273,-190.1 2280,-190.1"/>
+</g>
+<!-- 34 -->
+<g id="node23" class="node">
+<title>34</title>
+<polygon fill="none" stroke="black" points="1250,-3420 1081,-3420 1081,-3384 1250,-3384 1250,-3420"/>
+<text text-anchor="middle" x="1165.5" y="-3398.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 29&#45;&gt;34 -->
+<g id="edge5" class="edge">
+<title>29&#45;&gt;34</title>
+<path fill="none" stroke="black" d="M1656.57,-3455.97C1532.74,-3442.61 1364.53,-3424.47 1260.59,-3413.26"/>
+<polygon fill="black" stroke="black" points="1260.74,-3409.75 1250.42,-3412.16 1259.99,-3416.71 1260.74,-3409.75"/>
+</g>
+<!-- 32 -->
+<g id="node21" class="node">
+<title>32</title>
+<polygon fill="none" stroke="black" points="1231,-3564 770,-3564 770,-3528 1231,-3528 1231,-3564"/>
+<text text-anchor="middle" x="1000.5" y="-3542.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 768 768]| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 30&#45;&gt;32 -->
+<g id="edge3" class="edge">
+<title>30&#45;&gt;32</title>
+<path fill="none" stroke="black" d="M1000.5,-3599.7C1000.5,-3591.98 1000.5,-3582.71 1000.5,-3574.11"/>
+<polygon fill="black" stroke="black" points="1004,-3574.1 1000.5,-3564.1 997,-3574.1 1004,-3574.1"/>
+</g>
+<!-- 33 -->
+<g id="node22" class="node">
+<title>33</title>
+<polygon fill="none" stroke="black" points="1106,-3492 895,-3492 895,-3456 1106,-3456 1106,-3492"/>
+<text text-anchor="middle" x="1000.5" y="-3470.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 32&#45;&gt;33 -->
+<g id="edge4" class="edge">
+<title>32&#45;&gt;33</title>
+<path fill="none" stroke="black" d="M1000.5,-3527.7C1000.5,-3519.98 1000.5,-3510.71 1000.5,-3502.11"/>
+<polygon fill="black" stroke="black" points="1004,-3502.1 1000.5,-3492.1 997,-3502.1 1004,-3502.1"/>
+</g>
+<!-- 33&#45;&gt;34 -->
+<g id="edge6" class="edge">
+<title>33&#45;&gt;34</title>
+<path fill="none" stroke="black" d="M1040.86,-3455.88C1063.52,-3446.26 1091.97,-3434.19 1115.97,-3424.01"/>
+<polygon fill="black" stroke="black" points="1117.43,-3427.2 1125.27,-3420.07 1114.69,-3420.75 1117.43,-3427.2"/>
+</g>
+<!-- 36 -->
+<g id="node24" class="node">
+<title>36</title>
+<polygon fill="none" stroke="black" points="1387,-3348 944,-3348 944,-3312 1387,-3312 1387,-3348"/>
+<text text-anchor="middle" x="1165.5" y="-3326.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;1 &#160;14 768]| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 34&#45;&gt;36 -->
+<g id="edge7" class="edge">
+<title>34&#45;&gt;36</title>
+<path fill="none" stroke="black" d="M1165.5,-3383.7C1165.5,-3375.98 1165.5,-3366.71 1165.5,-3358.11"/>
+<polygon fill="black" stroke="black" points="1169,-3358.1 1165.5,-3348.1 1162,-3358.1 1169,-3358.1"/>
+</g>
+<!-- 36&#45;&gt;37 -->
+<g id="edge8" class="edge">
+<title>36&#45;&gt;37</title>
+<path fill="none" stroke="black" d="M1280.36,-3311.97C1380.49,-3297.12 1520.55,-3276.34 1591.86,-3265.77"/>
+<polygon fill="black" stroke="black" points="1592.49,-3269.21 1601.87,-3264.28 1591.47,-3262.29 1592.49,-3269.21"/>
+</g>
+<!-- 39 -->
+<g id="node26" class="node">
+<title>39</title>
+<polygon fill="none" stroke="black" points="1861,-3204 1482,-3204 1482,-3168 1861,-3168 1861,-3204"/>
+<text text-anchor="middle" x="1671.5" y="-3182.3" font-family="Times,serif" font-size="14.00">reshape(·, [ 1 14 12 64]| newshape=..., reverse=0)</text>
+</g>
+<!-- 37&#45;&gt;39 -->
+<g id="edge10" class="edge">
+<title>37&#45;&gt;39</title>
+<path fill="none" stroke="black" d="M1645.9,-3239.7C1649.77,-3231.73 1654.45,-3222.1 1658.74,-3213.26"/>
+<polygon fill="black" stroke="black" points="1661.97,-3214.63 1663.19,-3204.1 1655.67,-3211.57 1661.97,-3214.63"/>
+</g>
+<!-- 40 -->
+<g id="node27" class="node">
+<title>40</title>
+<polygon fill="none" stroke="black" points="1787,-3132 1558,-3132 1558,-3096 1787,-3096 1787,-3132"/>
+<text text-anchor="middle" x="1672.5" y="-3110.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 39&#45;&gt;40 -->
+<g id="edge11" class="edge">
+<title>39&#45;&gt;40</title>
+<path fill="none" stroke="black" d="M1671.75,-3167.7C1671.86,-3159.98 1671.99,-3150.71 1672.11,-3142.11"/>
+<polygon fill="black" stroke="black" points="1675.61,-3142.15 1672.26,-3132.1 1668.61,-3142.05 1675.61,-3142.15"/>
+</g>
+<!-- 42 -->
+<g id="node28" class="node">
+<title>42</title>
+<polygon fill="none" stroke="black" points="1884,-2916 1463,-2916 1463,-2880 1884,-2880 1884,-2916"/>
+<text text-anchor="middle" x="1673.5" y="-2894.3" font-family="Times,serif" font-size="14.00">reshape(·, [&#45;1 14 64]| newshape=[&#45;1, 14, 64], reverse=0)</text>
+</g>
+<!-- 40&#45;&gt;42 -->
+<g id="edge12" class="edge">
+<title>40&#45;&gt;42</title>
+<path fill="none" stroke="black" d="M1672.58,-3095.85C1672.75,-3058.83 1673.16,-2971.18 1673.37,-2926.39"/>
+<polygon fill="black" stroke="black" points="1676.87,-2926.25 1673.42,-2916.23 1669.87,-2926.22 1676.87,-2926.25"/>
+</g>
+<!-- 60 -->
+<g id="node41" class="node">
+<title>60</title>
+<polygon fill="none" stroke="black" points="2117,-2844 1948,-2844 1948,-2808 2117,-2808 2117,-2844"/>
+<text text-anchor="middle" x="2032.5" y="-2822.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 42&#45;&gt;60 -->
+<g id="edge27" class="edge">
+<title>42&#45;&gt;60</title>
+<path fill="none" stroke="black" d="M1760.86,-2879.97C1814.4,-2869.53 1882.94,-2856.16 1937.77,-2845.47"/>
+<polygon fill="black" stroke="black" points="1938.7,-2848.86 1947.85,-2843.51 1937.36,-2841.98 1938.7,-2848.86"/>
+</g>
+<!-- 49 -->
+<g id="node33" class="node">
+<title>49</title>
+<polygon fill="none" stroke="black" points="2655,-3420 2486,-3420 2486,-3384 2655,-3384 2655,-3420"/>
+<text text-anchor="middle" x="2570.5" y="-3398.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 44&#45;&gt;49 -->
+<g id="edge17" class="edge">
+<title>44&#45;&gt;49</title>
+<path fill="none" stroke="black" d="M2351.07,-3455.97C2392.99,-3445.85 2446.29,-3432.98 2489.87,-3422.46"/>
+<polygon fill="black" stroke="black" points="2490.96,-3425.8 2499.86,-3420.05 2489.32,-3418.99 2490.96,-3425.8"/>
+</g>
+<!-- 47 -->
+<g id="node31" class="node">
+<title>47</title>
+<polygon fill="none" stroke="black" points="2859,-3564 2398,-3564 2398,-3528 2859,-3528 2859,-3564"/>
+<text text-anchor="middle" x="2628.5" y="-3542.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 768 768]| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 45&#45;&gt;47 -->
+<g id="edge15" class="edge">
+<title>45&#45;&gt;47</title>
+<path fill="none" stroke="black" d="M2628.5,-3599.7C2628.5,-3591.98 2628.5,-3582.71 2628.5,-3574.11"/>
+<polygon fill="black" stroke="black" points="2632,-3574.1 2628.5,-3564.1 2625,-3574.1 2632,-3574.1"/>
+</g>
+<!-- 48 -->
+<g id="node32" class="node">
+<title>48</title>
+<polygon fill="none" stroke="black" points="2734,-3492 2523,-3492 2523,-3456 2734,-3456 2734,-3492"/>
+<text text-anchor="middle" x="2628.5" y="-3470.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 47&#45;&gt;48 -->
+<g id="edge16" class="edge">
+<title>47&#45;&gt;48</title>
+<path fill="none" stroke="black" d="M2628.5,-3527.7C2628.5,-3519.98 2628.5,-3510.71 2628.5,-3502.11"/>
+<polygon fill="black" stroke="black" points="2632,-3502.1 2628.5,-3492.1 2625,-3502.1 2632,-3502.1"/>
+</g>
+<!-- 48&#45;&gt;49 -->
+<g id="edge18" class="edge">
+<title>48&#45;&gt;49</title>
+<path fill="none" stroke="black" d="M2614.16,-3455.7C2607.21,-3447.3 2598.73,-3437.07 2591.1,-3427.86"/>
+<polygon fill="black" stroke="black" points="2593.75,-3425.57 2584.67,-3420.1 2588.36,-3430.04 2593.75,-3425.57"/>
+</g>
+<!-- 51 -->
+<g id="node34" class="node">
+<title>51</title>
+<polygon fill="none" stroke="black" points="2792,-3348 2349,-3348 2349,-3312 2792,-3312 2792,-3348"/>
+<text text-anchor="middle" x="2570.5" y="-3326.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;1 &#160;14 768]| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 49&#45;&gt;51 -->
+<g id="edge19" class="edge">
+<title>49&#45;&gt;51</title>
+<path fill="none" stroke="black" d="M2570.5,-3383.7C2570.5,-3375.98 2570.5,-3366.71 2570.5,-3358.11"/>
+<polygon fill="black" stroke="black" points="2574,-3358.1 2570.5,-3348.1 2567,-3358.1 2574,-3358.1"/>
+</g>
+<!-- 51&#45;&gt;52 -->
+<g id="edge20" class="edge">
+<title>51&#45;&gt;52</title>
+<path fill="none" stroke="black" d="M2458.32,-3311.97C2361.1,-3297.2 2225.35,-3276.59 2155.32,-3265.96"/>
+<polygon fill="black" stroke="black" points="2155.48,-3262.44 2145.07,-3264.4 2154.43,-3269.36 2155.48,-3262.44"/>
+</g>
+<!-- 54 -->
+<g id="node36" class="node">
+<title>54</title>
+<polygon fill="none" stroke="black" points="2296,-3204 1917,-3204 1917,-3168 2296,-3168 2296,-3204"/>
+<text text-anchor="middle" x="2106.5" y="-3182.3" font-family="Times,serif" font-size="14.00">reshape(·, [ 1 14 12 64]| newshape=..., reverse=0)</text>
+</g>
+<!-- 52&#45;&gt;54 -->
+<g id="edge22" class="edge">
+<title>52&#45;&gt;54</title>
+<path fill="none" stroke="black" d="M2108.76,-3239.7C2108.43,-3231.98 2108.03,-3222.71 2107.66,-3214.11"/>
+<polygon fill="black" stroke="black" points="2111.16,-3213.95 2107.23,-3204.1 2104.16,-3214.25 2111.16,-3213.95"/>
+</g>
+<!-- 55 -->
+<g id="node37" class="node">
+<title>55</title>
+<polygon fill="none" stroke="black" points="2210,-3132 1981,-3132 1981,-3096 2210,-3096 2210,-3132"/>
+<text text-anchor="middle" x="2095.5" y="-3110.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 54&#45;&gt;55 -->
+<g id="edge23" class="edge">
+<title>54&#45;&gt;55</title>
+<path fill="none" stroke="black" d="M2103.78,-3167.7C2102.57,-3159.98 2101.11,-3150.71 2099.76,-3142.11"/>
+<polygon fill="black" stroke="black" points="2103.2,-3141.44 2098.19,-3132.1 2096.28,-3142.53 2103.2,-3141.44"/>
+</g>
+<!-- 56 -->
+<g id="node38" class="node">
+<title>56</title>
+<polygon fill="none" stroke="black" points="2205,-3060 1976,-3060 1976,-3024 2205,-3024 2205,-3060"/>
+<text text-anchor="middle" x="2090.5" y="-3038.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 1, 3, 2])</text>
+</g>
+<!-- 55&#45;&gt;56 -->
+<g id="edge24" class="edge">
+<title>55&#45;&gt;56</title>
+<path fill="none" stroke="black" d="M2094.26,-3095.7C2093.71,-3087.98 2093.05,-3078.71 2092.44,-3070.11"/>
+<polygon fill="black" stroke="black" points="2095.93,-3069.83 2091.72,-3060.1 2088.94,-3070.33 2095.93,-3069.83"/>
+</g>
+<!-- 58 -->
+<g id="node39" class="node">
+<title>58</title>
+<polygon fill="none" stroke="black" points="2300,-2988 1879,-2988 1879,-2952 2300,-2952 2300,-2988"/>
+<text text-anchor="middle" x="2089.5" y="-2966.3" font-family="Times,serif" font-size="14.00">reshape(·, [&#45;1 64 14]| newshape=[&#45;1, 64, 14], reverse=0)</text>
+</g>
+<!-- 56&#45;&gt;58 -->
+<g id="edge25" class="edge">
+<title>56&#45;&gt;58</title>
+<path fill="none" stroke="black" d="M2090.25,-3023.7C2090.14,-3015.98 2090.01,-3006.71 2089.89,-2998.11"/>
+<polygon fill="black" stroke="black" points="2093.39,-2998.05 2089.74,-2988.1 2086.39,-2998.15 2093.39,-2998.05"/>
+</g>
+<!-- 59 -->
+<g id="node40" class="node">
+<title>59</title>
+<polygon fill="none" stroke="black" points="2160,-2916 1949,-2916 1949,-2880 2160,-2880 2160,-2916"/>
+<text text-anchor="middle" x="2054.5" y="-2894.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 58&#45;&gt;59 -->
+<g id="edge26" class="edge">
+<title>58&#45;&gt;59</title>
+<path fill="none" stroke="black" d="M2080.85,-2951.7C2076.86,-2943.73 2072.05,-2934.1 2067.63,-2925.26"/>
+<polygon fill="black" stroke="black" points="2070.65,-2923.48 2063.05,-2916.1 2064.39,-2926.61 2070.65,-2923.48"/>
+</g>
+<!-- 59&#45;&gt;60 -->
+<g id="edge28" class="edge">
+<title>59&#45;&gt;60</title>
+<path fill="none" stroke="black" d="M2049.06,-2879.7C2046.61,-2871.9 2043.66,-2862.51 2040.93,-2853.83"/>
+<polygon fill="black" stroke="black" points="2044.21,-2852.59 2037.88,-2844.1 2037.54,-2854.69 2044.21,-2852.59"/>
+</g>
+<!-- 62 -->
+<g id="node42" class="node">
+<title>62</title>
+<polygon fill="none" stroke="black" points="2222,-2772 1843,-2772 1843,-2736 2222,-2736 2222,-2772"/>
+<text text-anchor="middle" x="2032.5" y="-2750.3" font-family="Times,serif" font-size="14.00">reshape(·, [ 1 12 14 14]| newshape=..., reverse=0)</text>
+</g>
+<!-- 60&#45;&gt;62 -->
+<g id="edge29" class="edge">
+<title>60&#45;&gt;62</title>
+<path fill="none" stroke="black" d="M2032.5,-2807.7C2032.5,-2799.98 2032.5,-2790.71 2032.5,-2782.11"/>
+<polygon fill="black" stroke="black" points="2036,-2782.1 2032.5,-2772.1 2029,-2782.1 2036,-2782.1"/>
+</g>
+<!-- 64 -->
+<g id="node43" class="node">
+<title>64</title>
+<polygon fill="none" stroke="black" points="2086,-2700 1979,-2700 1979,-2664 2086,-2664 2086,-2700"/>
+<text text-anchor="middle" x="2032.5" y="-2678.3" font-family="Times,serif" font-size="14.00">divide(·, 8.0)</text>
+</g>
+<!-- 62&#45;&gt;64 -->
+<g id="edge30" class="edge">
+<title>62&#45;&gt;64</title>
+<path fill="none" stroke="black" d="M2032.5,-2735.7C2032.5,-2727.98 2032.5,-2718.71 2032.5,-2710.11"/>
+<polygon fill="black" stroke="black" points="2036,-2710.1 2032.5,-2700.1 2029,-2710.1 2036,-2710.1"/>
+</g>
+<!-- 64&#45;&gt;65 -->
+<g id="edge31" class="edge">
+<title>64&#45;&gt;65</title>
+<path fill="none" stroke="black" d="M2021.62,-2663.7C2016.51,-2655.56 2010.3,-2645.69 2004.66,-2636.7"/>
+<polygon fill="black" stroke="black" points="2007.54,-2634.71 1999.25,-2628.1 2001.61,-2638.43 2007.54,-2634.71"/>
+</g>
+<!-- 66 -->
+<g id="node45" class="node">
+<title>66</title>
+<polygon fill="none" stroke="black" points="2075.5,-2556 1901.5,-2556 1901.5,-2520 2075.5,-2520 2075.5,-2556"/>
+<text text-anchor="middle" x="1988.5" y="-2534.3" font-family="Times,serif" font-size="14.00">nn.softmax(·| axis=&#45;1)</text>
+</g>
+<!-- 65&#45;&gt;66 -->
+<g id="edge33" class="edge">
+<title>65&#45;&gt;66</title>
+<path fill="none" stroke="black" d="M1988.5,-2591.7C1988.5,-2583.98 1988.5,-2574.71 1988.5,-2566.11"/>
+<polygon fill="black" stroke="black" points="1992,-2566.1 1988.5,-2556.1 1985,-2566.1 1992,-2566.1"/>
+</g>
+<!-- 67 -->
+<g id="node46" class="node">
+<title>67</title>
+<polygon fill="none" stroke="black" points="2080,-2484 1897,-2484 1897,-2448 2080,-2448 2080,-2484"/>
+<text text-anchor="middle" x="1988.5" y="-2462.3" font-family="Times,serif" font-size="14.00">nn.dropout(·| rate=0.1)</text>
+</g>
+<!-- 66&#45;&gt;67 -->
+<g id="edge34" class="edge">
+<title>66&#45;&gt;67</title>
+<path fill="none" stroke="black" d="M1988.5,-2519.7C1988.5,-2511.98 1988.5,-2502.71 1988.5,-2494.11"/>
+<polygon fill="black" stroke="black" points="1992,-2494.1 1988.5,-2484.1 1985,-2494.1 1992,-2494.1"/>
+</g>
+<!-- 68 -->
+<g id="node47" class="node">
+<title>68</title>
+<polygon fill="none" stroke="black" points="2072.5,-2412 1904.5,-2412 1904.5,-2376 2072.5,-2376 2072.5,-2412"/>
+<text text-anchor="middle" x="1988.5" y="-2390.3" font-family="Times,serif" font-size="14.00">TupleGetItem(idx=0)</text>
+</g>
+<!-- 67&#45;&gt;68 -->
+<g id="edge35" class="edge">
+<title>67&#45;&gt;68</title>
+<path fill="none" stroke="black" d="M1988.5,-2447.7C1988.5,-2439.98 1988.5,-2430.71 1988.5,-2422.11"/>
+<polygon fill="black" stroke="black" points="1992,-2422.1 1988.5,-2412.1 1985,-2422.1 1992,-2422.1"/>
+</g>
+<!-- 70 -->
+<g id="node48" class="node">
+<title>70</title>
+<polygon fill="none" stroke="black" points="2199,-2340 1778,-2340 1778,-2304 2199,-2304 2199,-2340"/>
+<text text-anchor="middle" x="1988.5" y="-2318.3" font-family="Times,serif" font-size="14.00">reshape(·, [&#45;1 14 14]| newshape=[&#45;1, 14, 14], reverse=0)</text>
+</g>
+<!-- 68&#45;&gt;70 -->
+<g id="edge36" class="edge">
+<title>68&#45;&gt;70</title>
+<path fill="none" stroke="black" d="M1988.5,-2375.7C1988.5,-2367.98 1988.5,-2358.71 1988.5,-2350.11"/>
+<polygon fill="black" stroke="black" points="1992,-2350.1 1988.5,-2340.1 1985,-2350.1 1992,-2350.1"/>
+</g>
+<!-- 87 -->
+<g id="node60" class="node">
+<title>87</title>
+<polygon fill="none" stroke="black" points="2073,-2268 1904,-2268 1904,-2232 2073,-2232 2073,-2268"/>
+<text text-anchor="middle" x="1988.5" y="-2246.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 70&#45;&gt;87 -->
+<g id="edge50" class="edge">
+<title>70&#45;&gt;87</title>
+<path fill="none" stroke="black" d="M1988.5,-2303.7C1988.5,-2295.98 1988.5,-2286.71 1988.5,-2278.11"/>
+<polygon fill="black" stroke="black" points="1992,-2278.1 1988.5,-2268.1 1985,-2278.1 1992,-2278.1"/>
+</g>
+<!-- 77 -->
+<g id="node53" class="node">
+<title>77</title>
+<polygon fill="none" stroke="black" points="421,-3420 252,-3420 252,-3384 421,-3384 421,-3420"/>
+<text text-anchor="middle" x="336.5" y="-3398.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 72&#45;&gt;77 -->
+<g id="edge41" class="edge">
+<title>72&#45;&gt;77</title>
+<path fill="none" stroke="black" d="M1124.36,-3456.63C1121.39,-3456.42 1118.43,-3456.21 1115.5,-3456 868.89,-3438.61 578.3,-3419.1 431.55,-3409.32"/>
+<polygon fill="black" stroke="black" points="431.45,-3405.8 421.23,-3408.63 430.98,-3412.79 431.45,-3405.8"/>
+</g>
+<!-- 75 -->
+<g id="node51" class="node">
+<title>75</title>
+<polygon fill="none" stroke="black" points="567,-3564 106,-3564 106,-3528 567,-3528 567,-3564"/>
+<text text-anchor="middle" x="336.5" y="-3542.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 768 768]| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 73&#45;&gt;75 -->
+<g id="edge39" class="edge">
+<title>73&#45;&gt;75</title>
+<path fill="none" stroke="black" d="M336.5,-3599.7C336.5,-3591.98 336.5,-3582.71 336.5,-3574.11"/>
+<polygon fill="black" stroke="black" points="340,-3574.1 336.5,-3564.1 333,-3574.1 340,-3574.1"/>
+</g>
+<!-- 76 -->
+<g id="node52" class="node">
+<title>76</title>
+<polygon fill="none" stroke="black" points="442,-3492 231,-3492 231,-3456 442,-3456 442,-3492"/>
+<text text-anchor="middle" x="336.5" y="-3470.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 75&#45;&gt;76 -->
+<g id="edge40" class="edge">
+<title>75&#45;&gt;76</title>
+<path fill="none" stroke="black" d="M336.5,-3527.7C336.5,-3519.98 336.5,-3510.71 336.5,-3502.11"/>
+<polygon fill="black" stroke="black" points="340,-3502.1 336.5,-3492.1 333,-3502.1 340,-3502.1"/>
+</g>
+<!-- 76&#45;&gt;77 -->
+<g id="edge42" class="edge">
+<title>76&#45;&gt;77</title>
+<path fill="none" stroke="black" d="M336.5,-3455.7C336.5,-3447.98 336.5,-3438.71 336.5,-3430.11"/>
+<polygon fill="black" stroke="black" points="340,-3430.1 336.5,-3420.1 333,-3430.1 340,-3430.1"/>
+</g>
+<!-- 79 -->
+<g id="node54" class="node">
+<title>79</title>
+<polygon fill="none" stroke="black" points="443,-3348 0,-3348 0,-3312 443,-3312 443,-3348"/>
+<text text-anchor="middle" x="221.5" y="-3326.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;1 &#160;14 768]| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 77&#45;&gt;79 -->
+<g id="edge43" class="edge">
+<title>77&#45;&gt;79</title>
+<path fill="none" stroke="black" d="M308.37,-3383.88C293.33,-3374.72 274.63,-3363.34 258.44,-3353.48"/>
+<polygon fill="black" stroke="black" points="260.1,-3350.4 249.74,-3348.19 256.46,-3356.38 260.1,-3350.4"/>
+</g>
+<!-- 79&#45;&gt;80 -->
+<g id="edge44" class="edge">
+<title>79&#45;&gt;80</title>
+<path fill="none" stroke="black" d="M336.36,-3311.97C436.49,-3297.12 576.55,-3276.34 647.86,-3265.77"/>
+<polygon fill="black" stroke="black" points="648.49,-3269.21 657.87,-3264.28 647.47,-3262.29 648.49,-3269.21"/>
+</g>
+<!-- 82 -->
+<g id="node56" class="node">
+<title>82</title>
+<polygon fill="none" stroke="black" points="1156,-3204 777,-3204 777,-3168 1156,-3168 1156,-3204"/>
+<text text-anchor="middle" x="966.5" y="-3182.3" font-family="Times,serif" font-size="14.00">reshape(·, [ 1 14 12 64]| newshape=..., reverse=0)</text>
+</g>
+<!-- 80&#45;&gt;82 -->
+<g id="edge46" class="edge">
+<title>80&#45;&gt;82</title>
+<path fill="none" stroke="black" d="M729.19,-3247.85C769.86,-3237.42 837.49,-3220.08 890.18,-3206.57"/>
+<polygon fill="black" stroke="black" points="891.35,-3209.88 900.17,-3204.01 889.61,-3203.1 891.35,-3209.88"/>
+</g>
+<!-- 83 -->
+<g id="node57" class="node">
+<title>83</title>
+<polygon fill="none" stroke="black" points="1293,-3132 1064,-3132 1064,-3096 1293,-3096 1293,-3132"/>
+<text text-anchor="middle" x="1178.5" y="-3110.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 82&#45;&gt;83 -->
+<g id="edge47" class="edge">
+<title>82&#45;&gt;83</title>
+<path fill="none" stroke="black" d="M1018.09,-3167.97C1047.97,-3158.1 1085.76,-3145.62 1117.16,-3135.25"/>
+<polygon fill="black" stroke="black" points="1118.28,-3138.57 1126.68,-3132.11 1116.09,-3131.92 1118.28,-3138.57"/>
+</g>
+<!-- 85 -->
+<g id="node58" class="node">
+<title>85</title>
+<polygon fill="none" stroke="black" points="1499,-3060 1078,-3060 1078,-3024 1499,-3024 1499,-3060"/>
+<text text-anchor="middle" x="1288.5" y="-3038.3" font-family="Times,serif" font-size="14.00">reshape(·, [&#45;1 14 64]| newshape=[&#45;1, 14, 64], reverse=0)</text>
+</g>
+<!-- 83&#45;&gt;85 -->
+<g id="edge48" class="edge">
+<title>83&#45;&gt;85</title>
+<path fill="none" stroke="black" d="M1205.41,-3095.88C1219.66,-3086.81 1237.35,-3075.55 1252.74,-3065.76"/>
+<polygon fill="black" stroke="black" points="1254.93,-3068.51 1261.49,-3060.19 1251.17,-3062.61 1254.93,-3068.51"/>
+</g>
+<!-- 86 -->
+<g id="node59" class="node">
+<title>86</title>
+<polygon fill="none" stroke="black" points="1558,-2844 1347,-2844 1347,-2808 1558,-2808 1558,-2844"/>
+<text text-anchor="middle" x="1452.5" y="-2822.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 85&#45;&gt;86 -->
+<g id="edge49" class="edge">
+<title>85&#45;&gt;86</title>
+<path fill="none" stroke="black" d="M1301.64,-3023.85C1330.43,-2986.28 1399.18,-2896.58 1433.01,-2852.43"/>
+<polygon fill="black" stroke="black" points="1435.99,-2854.3 1439.29,-2844.23 1430.43,-2850.04 1435.99,-2854.3"/>
+</g>
+<!-- 86&#45;&gt;87 -->
+<g id="edge51" class="edge">
+<title>86&#45;&gt;87</title>
+<path fill="none" stroke="black" d="M1453.69,-2807.96C1455.43,-2781.31 1458.5,-2728.15 1458.5,-2683 1458.5,-2683 1458.5,-2683 1458.5,-2393 1458.5,-2303.95 1741.16,-2269.24 1893.63,-2256.98"/>
+<polygon fill="black" stroke="black" points="1894.24,-2260.45 1903.93,-2256.18 1893.69,-2253.47 1894.24,-2260.45"/>
+</g>
+<!-- 89 -->
+<g id="node61" class="node">
+<title>89</title>
+<polygon fill="none" stroke="black" points="2212,-2196 1833,-2196 1833,-2160 2212,-2160 2212,-2196"/>
+<text text-anchor="middle" x="2022.5" y="-2174.3" font-family="Times,serif" font-size="14.00">reshape(·, [ 1 12 14 64]| newshape=..., reverse=0)</text>
+</g>
+<!-- 87&#45;&gt;89 -->
+<g id="edge52" class="edge">
+<title>87&#45;&gt;89</title>
+<path fill="none" stroke="black" d="M1996.9,-2231.7C2000.77,-2223.73 2005.45,-2214.1 2009.74,-2205.26"/>
+<polygon fill="black" stroke="black" points="2012.97,-2206.63 2014.19,-2196.1 2006.67,-2203.57 2012.97,-2206.63"/>
+</g>
+<!-- 90 -->
+<g id="node62" class="node">
+<title>90</title>
+<polygon fill="none" stroke="black" points="2146,-2124 1917,-2124 1917,-2088 2146,-2088 2146,-2124"/>
+<text text-anchor="middle" x="2031.5" y="-2102.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 89&#45;&gt;90 -->
+<g id="edge53" class="edge">
+<title>89&#45;&gt;90</title>
+<path fill="none" stroke="black" d="M2024.72,-2159.7C2025.72,-2151.98 2026.91,-2142.71 2028.01,-2134.11"/>
+<polygon fill="black" stroke="black" points="2031.5,-2134.47 2029.3,-2124.1 2024.55,-2133.58 2031.5,-2134.47"/>
+</g>
+<!-- 91 -->
+<g id="node63" class="node">
+<title>91</title>
+<polygon fill="none" stroke="black" points="2081,-2052 2016,-2052 2016,-2016 2081,-2016 2081,-2052"/>
+<text text-anchor="middle" x="2048.5" y="-2030.3" font-family="Times,serif" font-size="14.00">copy(·)</text>
+</g>
+<!-- 90&#45;&gt;91 -->
+<g id="edge54" class="edge">
+<title>90&#45;&gt;91</title>
+<path fill="none" stroke="black" d="M2035.7,-2087.7C2037.6,-2079.9 2039.88,-2070.51 2041.98,-2061.83"/>
+<polygon fill="black" stroke="black" points="2045.39,-2062.65 2044.35,-2052.1 2038.58,-2061 2045.39,-2062.65"/>
+</g>
+<!-- 93 -->
+<g id="node64" class="node">
+<title>93</title>
+<polygon fill="none" stroke="black" points="2270,-1980 1827,-1980 1827,-1944 2270,-1944 2270,-1980"/>
+<text text-anchor="middle" x="2048.5" y="-1958.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;1 &#160;14 768]| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 91&#45;&gt;93 -->
+<g id="edge55" class="edge">
+<title>91&#45;&gt;93</title>
+<path fill="none" stroke="black" d="M2048.5,-2015.7C2048.5,-2007.98 2048.5,-1998.71 2048.5,-1990.11"/>
+<polygon fill="black" stroke="black" points="2052,-1990.1 2048.5,-1980.1 2045,-1990.1 2052,-1990.1"/>
+</g>
+<!-- 95 -->
+<g id="node65" class="node">
+<title>95</title>
+<polygon fill="none" stroke="black" points="2272.5,-1908 1824.5,-1908 1824.5,-1872 2272.5,-1872 2272.5,-1908"/>
+<text text-anchor="middle" x="2048.5" y="-1886.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 &#160;14 768]| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 93&#45;&gt;95 -->
+<g id="edge56" class="edge">
+<title>93&#45;&gt;95</title>
+<path fill="none" stroke="black" d="M2048.5,-1943.7C2048.5,-1935.98 2048.5,-1926.71 2048.5,-1918.11"/>
+<polygon fill="black" stroke="black" points="2052,-1918.1 2048.5,-1908.1 2045,-1918.1 2052,-1918.1"/>
+</g>
+<!-- 100 -->
+<g id="node69" class="node">
+<title>100</title>
+<polygon fill="none" stroke="black" points="2133,-1836 1964,-1836 1964,-1800 2133,-1800 2133,-1836"/>
+<text text-anchor="middle" x="2048.5" y="-1814.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 95&#45;&gt;100 -->
+<g id="edge60" class="edge">
+<title>95&#45;&gt;100</title>
+<path fill="none" stroke="black" d="M2048.5,-1871.7C2048.5,-1863.98 2048.5,-1854.71 2048.5,-1846.11"/>
+<polygon fill="black" stroke="black" points="2052,-1846.1 2048.5,-1836.1 2045,-1846.1 2052,-1846.1"/>
+</g>
+<!-- 98 -->
+<g id="node67" class="node">
+<title>98</title>
+<polygon fill="none" stroke="black" points="1809,-1980 1348,-1980 1348,-1944 1809,-1944 1809,-1980"/>
+<text text-anchor="middle" x="1578.5" y="-1958.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 768 768]| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 96&#45;&gt;98 -->
+<g id="edge58" class="edge">
+<title>96&#45;&gt;98</title>
+<path fill="none" stroke="black" d="M1578.5,-2015.7C1578.5,-2007.98 1578.5,-1998.71 1578.5,-1990.11"/>
+<polygon fill="black" stroke="black" points="1582,-1990.1 1578.5,-1980.1 1575,-1990.1 1582,-1990.1"/>
+</g>
+<!-- 99 -->
+<g id="node68" class="node">
+<title>99</title>
+<polygon fill="none" stroke="black" points="1745,-1908 1534,-1908 1534,-1872 1745,-1872 1745,-1908"/>
+<text text-anchor="middle" x="1639.5" y="-1886.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 98&#45;&gt;99 -->
+<g id="edge59" class="edge">
+<title>98&#45;&gt;99</title>
+<path fill="none" stroke="black" d="M1593.58,-1943.7C1600.89,-1935.3 1609.81,-1925.07 1617.84,-1915.86"/>
+<polygon fill="black" stroke="black" points="1620.66,-1917.94 1624.59,-1908.1 1615.39,-1913.34 1620.66,-1917.94"/>
+</g>
+<!-- 99&#45;&gt;100 -->
+<g id="edge61" class="edge">
+<title>99&#45;&gt;100</title>
+<path fill="none" stroke="black" d="M1739.03,-1871.97C1804.34,-1860.79 1889.26,-1846.25 1953.61,-1835.24"/>
+<polygon fill="black" stroke="black" points="1954.49,-1838.64 1963.76,-1833.5 1953.31,-1831.74 1954.49,-1838.64"/>
+</g>
+<!-- 102 -->
+<g id="node70" class="node">
+<title>102</title>
+<polygon fill="none" stroke="black" points="2274,-1764 1831,-1764 1831,-1728 2274,-1728 2274,-1764"/>
+<text text-anchor="middle" x="2052.5" y="-1742.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;1 &#160;14 768]| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 100&#45;&gt;102 -->
+<g id="edge62" class="edge">
+<title>100&#45;&gt;102</title>
+<path fill="none" stroke="black" d="M2049.49,-1799.7C2049.93,-1791.98 2050.46,-1782.71 2050.95,-1774.11"/>
+<polygon fill="black" stroke="black" points="2054.45,-1774.29 2051.52,-1764.1 2047.46,-1773.89 2054.45,-1774.29"/>
+</g>
+<!-- 102&#45;&gt;103 -->
+<g id="edge63" class="edge">
+<title>102&#45;&gt;103</title>
+<path fill="none" stroke="black" d="M2171.74,-1727.97C2276.37,-1713.02 2422.98,-1692.07 2496.55,-1681.56"/>
+<polygon fill="black" stroke="black" points="2497.46,-1684.97 2506.87,-1680.09 2496.47,-1678.04 2497.46,-1684.97"/>
+</g>
+<!-- 104 -->
+<g id="node72" class="node">
+<title>104</title>
+<polygon fill="none" stroke="black" points="2714,-1620 2531,-1620 2531,-1584 2714,-1584 2714,-1620"/>
+<text text-anchor="middle" x="2622.5" y="-1598.3" font-family="Times,serif" font-size="14.00">nn.dropout(·| rate=0.1)</text>
+</g>
+<!-- 103&#45;&gt;104 -->
+<g id="edge65" class="edge">
+<title>103&#45;&gt;104</title>
+<path fill="none" stroke="black" d="M2562.28,-1655.7C2572.25,-1646.97 2584.51,-1636.24 2595.36,-1626.75"/>
+<polygon fill="black" stroke="black" points="2597.73,-1629.32 2602.95,-1620.1 2593.12,-1624.06 2597.73,-1629.32"/>
+</g>
+<!-- 105 -->
+<g id="node73" class="node">
+<title>105</title>
+<polygon fill="none" stroke="black" points="2788.5,-1548 2620.5,-1548 2620.5,-1512 2788.5,-1512 2788.5,-1548"/>
+<text text-anchor="middle" x="2704.5" y="-1526.3" font-family="Times,serif" font-size="14.00">TupleGetItem(idx=0)</text>
+</g>
+<!-- 104&#45;&gt;105 -->
+<g id="edge66" class="edge">
+<title>104&#45;&gt;105</title>
+<path fill="none" stroke="black" d="M2642.77,-1583.7C2653,-1574.97 2665.56,-1564.24 2676.68,-1554.75"/>
+<polygon fill="black" stroke="black" points="2679.13,-1557.26 2684.46,-1548.1 2674.59,-1551.94 2679.13,-1557.26"/>
+</g>
+<!-- 105&#45;&gt;106 -->
+<g id="edge67" class="edge">
+<title>105&#45;&gt;106</title>
+<path fill="none" stroke="black" d="M2732.88,-1511.88C2748.04,-1502.72 2766.91,-1491.34 2783.24,-1481.48"/>
+<polygon fill="black" stroke="black" points="2785.26,-1484.35 2792.01,-1476.19 2781.64,-1478.36 2785.26,-1484.35"/>
+</g>
+<!-- 106&#45;&gt;107 -->
+<g id="edge69" class="edge">
+<title>106&#45;&gt;107</title>
+<path fill="none" stroke="black" d="M2784.82,-1442.89C2781.69,-1441.85 2778.55,-1440.86 2775.5,-1440 2725.82,-1425.94 2670.04,-1414.43 2621.77,-1405.76"/>
+<polygon fill="black" stroke="black" points="2622.36,-1402.31 2611.91,-1404.01 2621.14,-1409.21 2622.36,-1402.31"/>
+</g>
+<!-- 109 -->
+<g id="node76" class="node">
+<title>109</title>
+<polygon fill="none" stroke="black" points="2980.5,-1332 2532.5,-1332 2532.5,-1296 2980.5,-1296 2980.5,-1332"/>
+<text text-anchor="middle" x="2756.5" y="-1310.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 &#160;14 768]| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 107&#45;&gt;109 -->
+<g id="edge72" class="edge">
+<title>107&#45;&gt;109</title>
+<path fill="none" stroke="black" d="M2562.8,-1367.97C2599.34,-1357.97 2645.69,-1345.3 2683.89,-1334.85"/>
+<polygon fill="black" stroke="black" points="2685.2,-1338.12 2693.92,-1332.11 2683.35,-1331.37 2685.2,-1338.12"/>
+</g>
+<!-- 139 -->
+<g id="node97" class="node">
+<title>139</title>
+<polygon fill="none" stroke="black" points="2594,-252 2523,-252 2523,-216 2594,-216 2594,-252"/>
+<text text-anchor="middle" x="2558.5" y="-230.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 107&#45;&gt;139 -->
+<g id="edge99" class="edge">
+<title>107&#45;&gt;139</title>
+<path fill="none" stroke="black" d="M2500.5,-1367.95C2500.5,-1341.29 2500.5,-1288.11 2500.5,-1243 2500.5,-1243 2500.5,-1243 2500.5,-377 2500.5,-336.14 2506.91,-325.33 2523.5,-288 2527.69,-278.58 2533.49,-268.97 2539.16,-260.56"/>
+<polygon fill="black" stroke="black" points="2542.08,-262.5 2544.96,-252.3 2536.35,-258.48 2542.08,-262.5"/>
+</g>
+<!-- 114 -->
+<g id="node80" class="node">
+<title>114</title>
+<polygon fill="none" stroke="black" points="3017,-1260 2848,-1260 2848,-1224 3017,-1224 3017,-1260"/>
+<text text-anchor="middle" x="2932.5" y="-1238.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 109&#45;&gt;114 -->
+<g id="edge76" class="edge">
+<title>109&#45;&gt;114</title>
+<path fill="none" stroke="black" d="M2799.55,-1295.88C2823.83,-1286.22 2854.34,-1274.09 2880.01,-1263.88"/>
+<polygon fill="black" stroke="black" points="2881.59,-1267.02 2889.59,-1260.07 2879,-1260.51 2881.59,-1267.02"/>
+</g>
+<!-- 112 -->
+<g id="node78" class="node">
+<title>112</title>
+<polygon fill="none" stroke="black" points="3314,-1404 2917,-1404 2917,-1368 3314,-1368 3314,-1404"/>
+<text text-anchor="middle" x="3115.5" y="-1382.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;&#45;1 &#160;768 3072]| newshape=..., reverse=0)</text>
+</g>
+<!-- 110&#45;&gt;112 -->
+<g id="edge74" class="edge">
+<title>110&#45;&gt;112</title>
+<path fill="none" stroke="black" d="M3117.01,-1439.7C3116.79,-1431.98 3116.52,-1422.71 3116.27,-1414.11"/>
+<polygon fill="black" stroke="black" points="3119.77,-1414 3115.99,-1404.1 3112.78,-1414.2 3119.77,-1414"/>
+</g>
+<!-- 113 -->
+<g id="node79" class="node">
+<title>113</title>
+<polygon fill="none" stroke="black" points="3214,-1332 3003,-1332 3003,-1296 3214,-1296 3214,-1332"/>
+<text text-anchor="middle" x="3108.5" y="-1310.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 112&#45;&gt;113 -->
+<g id="edge75" class="edge">
+<title>112&#45;&gt;113</title>
+<path fill="none" stroke="black" d="M3113.77,-1367.7C3113,-1359.98 3112.07,-1350.71 3111.21,-1342.11"/>
+<polygon fill="black" stroke="black" points="3114.69,-1341.71 3110.21,-1332.1 3107.72,-1342.4 3114.69,-1341.71"/>
+</g>
+<!-- 113&#45;&gt;114 -->
+<g id="edge77" class="edge">
+<title>113&#45;&gt;114</title>
+<path fill="none" stroke="black" d="M3065.45,-1295.88C3041.17,-1286.22 3010.66,-1274.09 2984.99,-1263.88"/>
+<polygon fill="black" stroke="black" points="2986,-1260.51 2975.41,-1260.07 2983.41,-1267.02 2986,-1260.51"/>
+</g>
+<!-- 116 -->
+<g id="node81" class="node">
+<title>116</title>
+<polygon fill="none" stroke="black" points="3167.5,-1188 2697.5,-1188 2697.5,-1152 3167.5,-1152 3167.5,-1188"/>
+<text text-anchor="middle" x="2932.5" y="-1166.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;&#160;1 &#160;&#160;14 3072]| newshape=[1, 14, 3072], reverse=0)</text>
+</g>
+<!-- 114&#45;&gt;116 -->
+<g id="edge78" class="edge">
+<title>114&#45;&gt;116</title>
+<path fill="none" stroke="black" d="M2932.5,-1223.7C2932.5,-1215.98 2932.5,-1206.71 2932.5,-1198.11"/>
+<polygon fill="black" stroke="black" points="2936,-1198.1 2932.5,-1188.1 2929,-1198.1 2936,-1198.1"/>
+</g>
+<!-- 116&#45;&gt;117 -->
+<g id="edge79" class="edge">
+<title>116&#45;&gt;117</title>
+<path fill="none" stroke="black" d="M2983.85,-1151.97C3018.82,-1140.36 3064.69,-1125.15 3098.28,-1114"/>
+<polygon fill="black" stroke="black" points="3099.46,-1117.3 3107.85,-1110.83 3097.25,-1110.65 3099.46,-1117.3"/>
+</g>
+<!-- 121 -->
+<g id="node83" class="node">
+<title>121</title>
+<polygon fill="none" stroke="black" points="3296,-1044 3111,-1044 3111,-1008 3296,-1008 3296,-1044"/>
+<text text-anchor="middle" x="3203.5" y="-1022.3" font-family="Times,serif" font-size="14.00">multiply(·, 0.70710677)</text>
+</g>
+<!-- 117&#45;&gt;121 -->
+<g id="edge81" class="edge">
+<title>117&#45;&gt;121</title>
+<path fill="none" stroke="black" d="M3158.33,-1079.7C3165.52,-1071.3 3174.3,-1061.07 3182.19,-1051.86"/>
+<polygon fill="black" stroke="black" points="3184.99,-1053.97 3188.84,-1044.1 3179.67,-1049.42 3184.99,-1053.97"/>
+</g>
+<!-- 126 -->
+<g id="node87" class="node">
+<title>126</title>
+<polygon fill="none" stroke="black" points="3175.5,-756 3071.5,-756 3071.5,-720 3175.5,-720 3175.5,-756"/>
+<text text-anchor="middle" x="3123.5" y="-734.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 117&#45;&gt;126 -->
+<g id="edge85" class="edge">
+<title>117&#45;&gt;126</title>
+<path fill="none" stroke="black" d="M3126.88,-1079.7C3118.38,-1069.93 3108.57,-1057.09 3102.5,-1044 3085.48,-1007.31 3083.5,-995.45 3083.5,-955 3083.5,-955 3083.5,-955 3083.5,-881 3083.5,-839.48 3099.54,-793.61 3111.3,-765.54"/>
+<polygon fill="black" stroke="black" points="3114.65,-766.62 3115.41,-756.05 3108.23,-763.84 3114.65,-766.62"/>
+</g>
+<!-- 122 -->
+<g id="node84" class="node">
+<title>122</title>
+<polygon fill="none" stroke="black" points="3222.5,-972 3168.5,-972 3168.5,-936 3222.5,-936 3222.5,-972"/>
+<text text-anchor="middle" x="3195.5" y="-950.3" font-family="Times,serif" font-size="14.00">erf(·)</text>
+</g>
+<!-- 121&#45;&gt;122 -->
+<g id="edge82" class="edge">
+<title>121&#45;&gt;122</title>
+<path fill="none" stroke="black" d="M3201.52,-1007.7C3200.64,-999.98 3199.58,-990.71 3198.6,-982.11"/>
+<polygon fill="black" stroke="black" points="3202.07,-981.64 3197.45,-972.1 3195.11,-982.44 3202.07,-981.64"/>
+</g>
+<!-- 124 -->
+<g id="node85" class="node">
+<title>124</title>
+<polygon fill="none" stroke="black" points="3248.5,-900 3126.5,-900 3126.5,-864 3248.5,-864 3248.5,-900"/>
+<text text-anchor="middle" x="3187.5" y="-878.3" font-family="Times,serif" font-size="14.00">multiply(·, 0.5)</text>
+</g>
+<!-- 122&#45;&gt;124 -->
+<g id="edge83" class="edge">
+<title>122&#45;&gt;124</title>
+<path fill="none" stroke="black" d="M3193.52,-935.7C3192.64,-927.98 3191.58,-918.71 3190.6,-910.11"/>
+<polygon fill="black" stroke="black" points="3194.07,-909.64 3189.45,-900.1 3187.11,-910.44 3194.07,-909.64"/>
+</g>
+<!-- 125 -->
+<g id="node86" class="node">
+<title>125</title>
+<polygon fill="none" stroke="black" points="3209,-828 3120,-828 3120,-792 3209,-792 3209,-828"/>
+<text text-anchor="middle" x="3164.5" y="-806.3" font-family="Times,serif" font-size="14.00">add(0.5, ·)</text>
+</g>
+<!-- 124&#45;&gt;125 -->
+<g id="edge84" class="edge">
+<title>124&#45;&gt;125</title>
+<path fill="none" stroke="black" d="M3181.81,-863.7C3179.25,-855.9 3176.17,-846.51 3173.32,-837.83"/>
+<polygon fill="black" stroke="black" points="3176.57,-836.51 3170.12,-828.1 3169.92,-838.7 3176.57,-836.51"/>
+</g>
+<!-- 125&#45;&gt;126 -->
+<g id="edge86" class="edge">
+<title>125&#45;&gt;126</title>
+<path fill="none" stroke="black" d="M3154.37,-791.7C3149.65,-783.64 3143.94,-773.89 3138.72,-764.98"/>
+<polygon fill="black" stroke="black" points="3141.59,-762.96 3133.52,-756.1 3135.55,-766.5 3141.59,-762.96"/>
+</g>
+<!-- 128 -->
+<g id="node88" class="node">
+<title>128</title>
+<polygon fill="none" stroke="black" points="3361,-684 2886,-684 2886,-648 3361,-648 3361,-684"/>
+<text text-anchor="middle" x="3123.5" y="-662.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;&#45;1 &#160;&#160;14 3072]| newshape=[&#45;1, 14, 3072], reverse=0)</text>
+</g>
+<!-- 126&#45;&gt;128 -->
+<g id="edge87" class="edge">
+<title>126&#45;&gt;128</title>
+<path fill="none" stroke="black" d="M3123.5,-719.7C3123.5,-711.98 3123.5,-702.71 3123.5,-694.11"/>
+<polygon fill="black" stroke="black" points="3127,-694.1 3123.5,-684.1 3120,-694.1 3127,-694.1"/>
+</g>
+<!-- 133 -->
+<g id="node92" class="node">
+<title>133</title>
+<polygon fill="none" stroke="black" points="2847,-612 2678,-612 2678,-576 2847,-576 2847,-612"/>
+<text text-anchor="middle" x="2762.5" y="-590.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 128&#45;&gt;133 -->
+<g id="edge91" class="edge">
+<title>128&#45;&gt;133</title>
+<path fill="none" stroke="black" d="M3035.65,-647.97C2981.62,-637.49 2912.39,-624.07 2857.15,-613.35"/>
+<polygon fill="black" stroke="black" points="2857.49,-609.85 2847.01,-611.39 2856.16,-616.73 2857.49,-609.85"/>
+</g>
+<!-- 131 -->
+<g id="node90" class="node">
+<title>131</title>
+<polygon fill="none" stroke="black" points="2963,-756 2566,-756 2566,-720 2963,-720 2963,-756"/>
+<text text-anchor="middle" x="2764.5" y="-734.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;&#45;1 3072 &#160;768]| newshape=..., reverse=0)</text>
+</g>
+<!-- 129&#45;&gt;131 -->
+<g id="edge89" class="edge">
+<title>129&#45;&gt;131</title>
+<path fill="none" stroke="black" d="M2765.25,-791.7C2765.14,-783.98 2765.01,-774.71 2764.89,-766.11"/>
+<polygon fill="black" stroke="black" points="2768.39,-766.05 2764.74,-756.1 2761.39,-766.15 2768.39,-766.05"/>
+</g>
+<!-- 132 -->
+<g id="node91" class="node">
+<title>132</title>
+<polygon fill="none" stroke="black" points="2868,-684 2657,-684 2657,-648 2868,-648 2868,-684"/>
+<text text-anchor="middle" x="2762.5" y="-662.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 131&#45;&gt;132 -->
+<g id="edge90" class="edge">
+<title>131&#45;&gt;132</title>
+<path fill="none" stroke="black" d="M2764.01,-719.7C2763.79,-711.98 2763.52,-702.71 2763.27,-694.11"/>
+<polygon fill="black" stroke="black" points="2766.77,-694 2762.99,-684.1 2759.78,-694.2 2766.77,-694"/>
+</g>
+<!-- 132&#45;&gt;133 -->
+<g id="edge92" class="edge">
+<title>132&#45;&gt;133</title>
+<path fill="none" stroke="black" d="M2762.5,-647.7C2762.5,-639.98 2762.5,-630.71 2762.5,-622.11"/>
+<polygon fill="black" stroke="black" points="2766,-622.1 2762.5,-612.1 2759,-622.1 2766,-622.1"/>
+</g>
+<!-- 135 -->
+<g id="node93" class="node">
+<title>135</title>
+<polygon fill="none" stroke="black" points="2982,-540 2539,-540 2539,-504 2982,-504 2982,-540"/>
+<text text-anchor="middle" x="2760.5" y="-518.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;1 &#160;14 768]| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 133&#45;&gt;135 -->
+<g id="edge93" class="edge">
+<title>133&#45;&gt;135</title>
+<path fill="none" stroke="black" d="M2762.01,-575.7C2761.79,-567.98 2761.52,-558.71 2761.27,-550.11"/>
+<polygon fill="black" stroke="black" points="2764.77,-550 2760.99,-540.1 2757.78,-550.2 2764.77,-550"/>
+</g>
+<!-- 135&#45;&gt;136 -->
+<g id="edge94" class="edge">
+<title>135&#45;&gt;136</title>
+<path fill="none" stroke="black" d="M2760.5,-503.7C2760.5,-495.98 2760.5,-486.71 2760.5,-478.11"/>
+<polygon fill="black" stroke="black" points="2764,-478.1 2760.5,-468.1 2757,-478.1 2764,-478.1"/>
+</g>
+<!-- 137 -->
+<g id="node95" class="node">
+<title>137</title>
+<polygon fill="none" stroke="black" points="2777,-396 2594,-396 2594,-360 2777,-360 2777,-396"/>
+<text text-anchor="middle" x="2685.5" y="-374.3" font-family="Times,serif" font-size="14.00">nn.dropout(·| rate=0.1)</text>
+</g>
+<!-- 136&#45;&gt;137 -->
+<g id="edge96" class="edge">
+<title>136&#45;&gt;137</title>
+<path fill="none" stroke="black" d="M2741.96,-431.7C2732.7,-423.05 2721.34,-412.45 2711.24,-403.03"/>
+<polygon fill="black" stroke="black" points="2713.52,-400.37 2703.83,-396.1 2708.75,-405.49 2713.52,-400.37"/>
+</g>
+<!-- 138 -->
+<g id="node96" class="node">
+<title>138</title>
+<polygon fill="none" stroke="black" points="2700.5,-324 2532.5,-324 2532.5,-288 2700.5,-288 2700.5,-324"/>
+<text text-anchor="middle" x="2616.5" y="-302.3" font-family="Times,serif" font-size="14.00">TupleGetItem(idx=0)</text>
+</g>
+<!-- 137&#45;&gt;138 -->
+<g id="edge97" class="edge">
+<title>137&#45;&gt;138</title>
+<path fill="none" stroke="black" d="M2668.44,-359.7C2660,-351.14 2649.68,-340.66 2640.46,-331.3"/>
+<polygon fill="black" stroke="black" points="2642.87,-328.77 2633.36,-324.1 2637.89,-333.68 2642.87,-328.77"/>
+</g>
+<!-- 138&#45;&gt;139 -->
+<g id="edge98" class="edge">
+<title>138&#45;&gt;139</title>
+<path fill="none" stroke="black" d="M2602.16,-287.7C2595.21,-279.3 2586.73,-269.07 2579.1,-259.86"/>
+<polygon fill="black" stroke="black" points="2581.75,-257.57 2572.67,-252.1 2576.36,-262.04 2581.75,-257.57"/>
+</g>
+<!-- 139&#45;&gt;140 -->
+<g id="edge100" class="edge">
+<title>139&#45;&gt;140</title>
+<path fill="none" stroke="black" d="M2522.63,-218.64C2519.9,-217.7 2517.16,-216.81 2514.5,-216 2470.24,-202.6 2420.42,-190.99 2377.83,-182.08"/>
+<polygon fill="black" stroke="black" points="2378.52,-178.64 2368.02,-180.04 2377.1,-185.5 2378.52,-178.64"/>
+</g>
+<!-- 141 -->
+<g id="node99" class="node">
+<title>141</title>
+<polygon fill="none" stroke="black" points="2319.5,-108 2233.5,-108 2233.5,-72 2319.5,-72 2319.5,-108"/>
+<text text-anchor="middle" x="2276.5" y="-86.3" font-family="Times,serif" font-size="14.00">Tuple[...])</text>
+</g>
+<!-- 140&#45;&gt;141 -->
+<g id="edge103" class="edge">
+<title>140&#45;&gt;141</title>
+<path fill="none" stroke="black" d="M2276.5,-143.7C2276.5,-135.98 2276.5,-126.71 2276.5,-118.11"/>
+<polygon fill="black" stroke="black" points="2280,-118.1 2276.5,-108.1 2273,-118.1 2280,-118.1"/>
+</g>
+<!-- 142 -->
+<g id="node100" class="node">
+<title>142</title>
+<polygon fill="none" stroke="black" points="2316.5,-36 2236.5,-36 2236.5,0 2316.5,0 2316.5,-36"/>
+<text text-anchor="middle" x="2276.5" y="-14.3" font-family="Times,serif" font-size="14.00">Function</text>
+</g>
+<!-- 141&#45;&gt;142 -->
+<g id="edge104" class="edge">
+<title>141&#45;&gt;142</title>
+<path fill="none" stroke="black" d="M2276.5,-71.7C2276.5,-63.98 2276.5,-54.71 2276.5,-46.11"/>
+<polygon fill="black" stroke="black" points="2280,-46.1 2276.5,-36.1 2273,-46.1 2280,-46.1"/>
+</g>
+</g>
+</svg>
diff --git a/images/bert-pytorch/pytorch-tvm-training_25_0.svg b/images/bert-pytorch/pytorch-tvm-training_25_0.svg
new file mode 100644
index 0000000..707f26c
--- /dev/null
+++ b/images/bert-pytorch/pytorch-tvm-training_25_0.svg
@@ -0,0 +1,1537 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
+ "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
+<!-- Generated by graphviz version 2.43.0 (0)
+ -->
+<!-- Title: %3 Pages: 1 -->
+<svg width="3227pt" height="4508pt"
+ viewBox="0.00 0.00 3226.81 4508.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
+<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 4504)">
+<title>%3</title>
+<polygon fill="white" stroke="transparent" points="-4,4 -4,-4504 3222.81,-4504 3222.81,4 -4,4"/>
+<!-- 0 -->
+<g id="node1" class="node">
+<title>0</title>
+<ellipse fill="none" stroke="black" cx="1019.18" cy="-4338" rx="170.87" ry="18"/>
+<text text-anchor="middle" x="1019.18" y="-4334.3" font-family="Times,serif" font-size="14.00">input: Tensor[(1, 14, 768), float32]</text>
+</g>
+<!-- 33 -->
+<g id="node23" class="node">
+<title>33</title>
+<polygon fill="none" stroke="black" points="1476.18,-4284 1028.18,-4284 1028.18,-4248 1476.18,-4248 1476.18,-4284"/>
+<text text-anchor="middle" x="1252.18" y="-4262.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 &#160;14 768]| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 0&#45;&gt;33 -->
+<g id="edge1" class="edge">
+<title>0&#45;&gt;33</title>
+<path fill="none" stroke="black" d="M1072.92,-4320.85C1106.46,-4310.78 1149.89,-4297.73 1185.6,-4287"/>
+<polygon fill="black" stroke="black" points="1186.88,-4290.27 1195.45,-4284.04 1184.86,-4283.57 1186.88,-4290.27"/>
+</g>
+<!-- 47 -->
+<g id="node33" class="node">
+<title>47</title>
+<polygon fill="none" stroke="black" points="1942.18,-4284 1494.18,-4284 1494.18,-4248 1942.18,-4248 1942.18,-4284"/>
+<text text-anchor="middle" x="1718.18" y="-4262.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 &#160;14 768]| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 0&#45;&gt;47 -->
+<g id="edge13" class="edge">
+<title>0&#45;&gt;47</title>
+<path fill="none" stroke="black" d="M1138.54,-4325.05C1248.26,-4314.06 1412.45,-4297.62 1538.1,-4285.03"/>
+<polygon fill="black" stroke="black" points="1538.62,-4288.5 1548.22,-4284.02 1537.92,-4281.53 1538.62,-4288.5"/>
+</g>
+<!-- 71 -->
+<g id="node52" class="node">
+<title>71</title>
+<polygon fill="none" stroke="black" points="1010.18,-4284 562.18,-4284 562.18,-4248 1010.18,-4248 1010.18,-4284"/>
+<text text-anchor="middle" x="786.18" y="-4262.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 &#160;14 768]| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 0&#45;&gt;71 -->
+<g id="edge37" class="edge">
+<title>0&#45;&gt;71</title>
+<path fill="none" stroke="black" d="M965.43,-4320.85C931.89,-4310.78 888.47,-4297.73 852.76,-4287"/>
+<polygon fill="black" stroke="black" points="853.49,-4283.57 842.9,-4284.04 851.47,-4290.27 853.49,-4283.57"/>
+</g>
+<!-- 98 -->
+<g id="node77" class="node">
+<title>98</title>
+<polygon fill="none" stroke="black" points="544.68,-2484 473.68,-2484 473.68,-2448 544.68,-2448 544.68,-2484"/>
+<text text-anchor="middle" x="509.18" y="-2462.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 0&#45;&gt;98 -->
+<g id="edge69" class="edge">
+<title>0&#45;&gt;98</title>
+<path fill="none" stroke="black" d="M859.11,-4331.63C675.41,-4321.01 399.18,-4289.35 399.18,-4195 399.18,-4195 399.18,-4195 399.18,-2609 399.18,-2558.76 442.74,-2515.52 475.04,-2490.36"/>
+<polygon fill="black" stroke="black" points="477.51,-2492.88 483.39,-2484.06 473.3,-2487.29 477.51,-2492.88"/>
+</g>
+<!-- 1 -->
+<g id="node2" class="node">
+<title>1</title>
+<ellipse fill="none" stroke="black" cx="1783.18" cy="-3546" rx="217.96" ry="18"/>
+<text text-anchor="middle" x="1783.18" y="-3542.3" font-family="Times,serif" font-size="14.00">attention_mask: Tensor[(1, 1, 1, 14), float32]</text>
+</g>
+<!-- 64 -->
+<g id="node47" class="node">
+<title>64</title>
+<polygon fill="none" stroke="black" points="1818.68,-3492 1747.68,-3492 1747.68,-3456 1818.68,-3456 1818.68,-3492"/>
+<text text-anchor="middle" x="1783.18" y="-3470.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 1&#45;&gt;64 -->
+<g id="edge31" class="edge">
+<title>1&#45;&gt;64</title>
+<path fill="none" stroke="black" d="M1783.18,-3527.7C1783.18,-3519.98 1783.18,-3510.71 1783.18,-3502.11"/>
+<polygon fill="black" stroke="black" points="1786.68,-3502.1 1783.18,-3492.1 1779.68,-3502.1 1786.68,-3502.1"/>
+</g>
+<!-- 2 -->
+<g id="node3" class="node">
+<title>2</title>
+<ellipse fill="none" stroke="black" cx="2049.18" cy="-4482" rx="265.65" ry="18"/>
+<text text-anchor="middle" x="2049.18" y="-4478.3" font-family="Times,serif" font-size="14.00">attention.self.query.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 34 -->
+<g id="node24" class="node">
+<title>34</title>
+<polygon fill="none" stroke="black" points="2149.68,-4428 1956.68,-4428 1956.68,-4392 2149.68,-4392 2149.68,-4428"/>
+<text text-anchor="middle" x="2053.18" y="-4406.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 2&#45;&gt;34 -->
+<g id="edge2" class="edge">
+<title>2&#45;&gt;34</title>
+<path fill="none" stroke="black" d="M2050.17,-4463.7C2050.61,-4455.98 2051.14,-4446.71 2051.63,-4438.11"/>
+<polygon fill="black" stroke="black" points="2055.12,-4438.29 2052.2,-4428.1 2048.13,-4437.89 2055.12,-4438.29"/>
+</g>
+<!-- 3 -->
+<g id="node4" class="node">
+<title>3</title>
+<ellipse fill="none" stroke="black" cx="1603.18" cy="-4122" rx="232.86" ry="18"/>
+<text text-anchor="middle" x="1603.18" y="-4118.3" font-family="Times,serif" font-size="14.00">attention.self.query.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 41 -->
+<g id="node29" class="node">
+<title>41</title>
+<polygon fill="none" stroke="black" points="2108.68,-4068 2037.68,-4068 2037.68,-4032 2108.68,-4032 2108.68,-4068"/>
+<text text-anchor="middle" x="2073.18" y="-4046.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 3&#45;&gt;41 -->
+<g id="edge9" class="edge">
+<title>3&#45;&gt;41</title>
+<path fill="none" stroke="black" d="M1705.16,-4105.81C1805.34,-4090.89 1953.6,-4068.81 2027.64,-4057.78"/>
+<polygon fill="black" stroke="black" points="2028.2,-4061.24 2037.58,-4056.3 2027.17,-4054.31 2028.2,-4061.24"/>
+</g>
+<!-- 4 -->
+<g id="node5" class="node">
+<title>4</title>
+<ellipse fill="none" stroke="black" cx="2588.18" cy="-4482" rx="254.55" ry="18"/>
+<text text-anchor="middle" x="2588.18" y="-4478.3" font-family="Times,serif" font-size="14.00">attention.self.key.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 48 -->
+<g id="node34" class="node">
+<title>48</title>
+<polygon fill="none" stroke="black" points="2654.68,-4428 2461.68,-4428 2461.68,-4392 2654.68,-4392 2654.68,-4428"/>
+<text text-anchor="middle" x="2558.18" y="-4406.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 4&#45;&gt;48 -->
+<g id="edge14" class="edge">
+<title>4&#45;&gt;48</title>
+<path fill="none" stroke="black" d="M2580.76,-4463.7C2577.38,-4455.81 2573.31,-4446.3 2569.55,-4437.55"/>
+<polygon fill="black" stroke="black" points="2572.66,-4435.92 2565.51,-4428.1 2566.23,-4438.67 2572.66,-4435.92"/>
+</g>
+<!-- 5 -->
+<g id="node6" class="node">
+<title>5</title>
+<ellipse fill="none" stroke="black" cx="2997.18" cy="-4122" rx="221.76" ry="18"/>
+<text text-anchor="middle" x="2997.18" y="-4118.3" font-family="Times,serif" font-size="14.00">attention.self.key.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 53 -->
+<g id="node39" class="node">
+<title>53</title>
+<polygon fill="none" stroke="black" points="2571.68,-4068 2500.68,-4068 2500.68,-4032 2571.68,-4032 2571.68,-4068"/>
+<text text-anchor="middle" x="2536.18" y="-4046.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 5&#45;&gt;53 -->
+<g id="edge21" class="edge">
+<title>5&#45;&gt;53</title>
+<path fill="none" stroke="black" d="M2897.71,-4105.9C2799.95,-4091.05 2655.1,-4069.06 2581.97,-4057.95"/>
+<polygon fill="black" stroke="black" points="2582.13,-4054.44 2571.72,-4056.4 2581.08,-4061.36 2582.13,-4054.44"/>
+</g>
+<!-- 6 -->
+<g id="node7" class="node">
+<title>6</title>
+<ellipse fill="none" stroke="black" cx="265.18" cy="-4482" rx="265.35" ry="18"/>
+<text text-anchor="middle" x="265.18" y="-4478.3" font-family="Times,serif" font-size="14.00">attention.self.value.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 72 -->
+<g id="node53" class="node">
+<title>72</title>
+<polygon fill="none" stroke="black" points="361.68,-4428 168.68,-4428 168.68,-4392 361.68,-4392 361.68,-4428"/>
+<text text-anchor="middle" x="265.18" y="-4406.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 6&#45;&gt;72 -->
+<g id="edge38" class="edge">
+<title>6&#45;&gt;72</title>
+<path fill="none" stroke="black" d="M265.18,-4463.7C265.18,-4455.98 265.18,-4446.71 265.18,-4438.11"/>
+<polygon fill="black" stroke="black" points="268.68,-4438.1 265.18,-4428.1 261.68,-4438.1 268.68,-4438.1"/>
+</g>
+<!-- 7 -->
+<g id="node8" class="node">
+<title>7</title>
+<ellipse fill="none" stroke="black" cx="659.18" cy="-4122" rx="232.06" ry="18"/>
+<text text-anchor="middle" x="659.18" y="-4118.3" font-family="Times,serif" font-size="14.00">attention.self.value.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 77 -->
+<g id="node58" class="node">
+<title>77</title>
+<polygon fill="none" stroke="black" points="1166.68,-4068 1095.68,-4068 1095.68,-4032 1166.68,-4032 1166.68,-4068"/>
+<text text-anchor="middle" x="1131.18" y="-4046.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 7&#45;&gt;77 -->
+<g id="edge45" class="edge">
+<title>7&#45;&gt;77</title>
+<path fill="none" stroke="black" d="M761.59,-4105.81C862.2,-4090.89 1011.09,-4068.81 1085.44,-4057.78"/>
+<polygon fill="black" stroke="black" points="1086.05,-4061.23 1095.43,-4056.3 1085.02,-4054.31 1086.05,-4061.23"/>
+</g>
+<!-- 8 -->
+<g id="node9" class="node">
+<title>8</title>
+<ellipse fill="none" stroke="black" cx="863.18" cy="-3042" rx="282.15" ry="18"/>
+<text text-anchor="middle" x="863.18" y="-3038.3" font-family="Times,serif" font-size="14.00">attention.output.dense.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 89 -->
+<g id="node69" class="node">
+<title>89</title>
+<polygon fill="none" stroke="black" points="959.68,-2988 766.68,-2988 766.68,-2952 959.68,-2952 959.68,-2988"/>
+<text text-anchor="middle" x="863.18" y="-2966.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 8&#45;&gt;89 -->
+<g id="edge57" class="edge">
+<title>8&#45;&gt;89</title>
+<path fill="none" stroke="black" d="M863.18,-3023.7C863.18,-3015.98 863.18,-3006.71 863.18,-2998.11"/>
+<polygon fill="black" stroke="black" points="866.68,-2998.1 863.18,-2988.1 859.68,-2998.1 866.68,-2998.1"/>
+</g>
+<!-- 9 -->
+<g id="node10" class="node">
+<title>9</title>
+<ellipse fill="none" stroke="black" cx="1539.18" cy="-2682" rx="248.86" ry="18"/>
+<text text-anchor="middle" x="1539.18" y="-2678.3" font-family="Times,serif" font-size="14.00">attention.output.dense.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 94 -->
+<g id="node74" class="node">
+<title>94</title>
+<polygon fill="none" stroke="black" points="1086.68,-2628 1015.68,-2628 1015.68,-2592 1086.68,-2592 1086.68,-2628"/>
+<text text-anchor="middle" x="1051.18" y="-2606.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 9&#45;&gt;94 -->
+<g id="edge64" class="edge">
+<title>9&#45;&gt;94</title>
+<path fill="none" stroke="black" d="M1432.69,-2665.72C1328.09,-2650.72 1173.44,-2628.54 1097.16,-2617.6"/>
+<polygon fill="black" stroke="black" points="1097.33,-2614.08 1086.93,-2616.13 1096.33,-2621.01 1097.33,-2614.08"/>
+</g>
+<!-- 10 -->
+<g id="node11" class="node">
+<title>10</title>
+<ellipse fill="none" stroke="black" cx="858.18" cy="-2106" rx="286.75" ry="18"/>
+<text text-anchor="middle" x="858.18" y="-2102.3" font-family="Times,serif" font-size="14.00">attention.output.LayerNorm.weight: Tensor[(768,), float32]</text>
+</g>
+<!-- 110 -->
+<g id="node85" class="node">
+<title>110</title>
+<polygon fill="none" stroke="black" points="805.18,-2052 701.18,-2052 701.18,-2016 805.18,-2016 805.18,-2052"/>
+<text text-anchor="middle" x="753.18" y="-2030.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 10&#45;&gt;110 -->
+<g id="edge81" class="edge">
+<title>10&#45;&gt;110</title>
+<path fill="none" stroke="black" d="M832.76,-2088.05C819.06,-2078.92 801.97,-2067.53 787.15,-2057.65"/>
+<polygon fill="black" stroke="black" points="788.99,-2054.67 778.72,-2052.03 785.1,-2060.49 788.99,-2054.67"/>
+</g>
+<!-- 11 -->
+<g id="node12" class="node">
+<title>11</title>
+<ellipse fill="none" stroke="black" cx="1097.18" cy="-2034" rx="274.05" ry="18"/>
+<text text-anchor="middle" x="1097.18" y="-2030.3" font-family="Times,serif" font-size="14.00">attention.output.LayerNorm.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 111 -->
+<g id="node86" class="node">
+<title>111</title>
+<polygon fill="none" stroke="black" points="1079.68,-1980 1008.68,-1980 1008.68,-1944 1079.68,-1944 1079.68,-1980"/>
+<text text-anchor="middle" x="1044.18" y="-1958.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 11&#45;&gt;111 -->
+<g id="edge83" class="edge">
+<title>11&#45;&gt;111</title>
+<path fill="none" stroke="black" d="M1084.08,-2015.7C1077.79,-2007.39 1070.13,-1997.28 1063.21,-1988.14"/>
+<polygon fill="black" stroke="black" points="1065.95,-1985.96 1057.13,-1980.1 1060.37,-1990.19 1065.95,-1985.96"/>
+</g>
+<!-- 12 -->
+<g id="node13" class="node">
+<title>12</title>
+<ellipse fill="none" stroke="black" cx="1539.18" cy="-2106" rx="271.85" ry="18"/>
+<text text-anchor="middle" x="1539.18" y="-2102.3" font-family="Times,serif" font-size="14.00">intermediate.dense.weight: Tensor[(3072, 768), float32]</text>
+</g>
+<!-- 113 -->
+<g id="node88" class="node">
+<title>113</title>
+<polygon fill="none" stroke="black" points="1635.68,-2052 1442.68,-2052 1442.68,-2016 1635.68,-2016 1635.68,-2052"/>
+<text text-anchor="middle" x="1539.18" y="-2030.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 12&#45;&gt;113 -->
+<g id="edge85" class="edge">
+<title>12&#45;&gt;113</title>
+<path fill="none" stroke="black" d="M1539.18,-2087.7C1539.18,-2079.98 1539.18,-2070.71 1539.18,-2062.11"/>
+<polygon fill="black" stroke="black" points="1542.68,-2062.1 1539.18,-2052.1 1535.68,-2062.1 1542.68,-2062.1"/>
+</g>
+<!-- 13 -->
+<g id="node14" class="node">
+<title>13</title>
+<ellipse fill="none" stroke="black" cx="1862.18" cy="-1746" rx="238.56" ry="18"/>
+<text text-anchor="middle" x="1862.18" y="-1742.3" font-family="Times,serif" font-size="14.00">intermediate.dense.bias: Tensor[(3072,), float32]</text>
+</g>
+<!-- 120 -->
+<g id="node93" class="node">
+<title>120</title>
+<polygon fill="none" stroke="black" points="1767.68,-1692 1696.68,-1692 1696.68,-1656 1767.68,-1656 1767.68,-1692"/>
+<text text-anchor="middle" x="1732.18" y="-1670.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 13&#45;&gt;120 -->
+<g id="edge92" class="edge">
+<title>13&#45;&gt;120</title>
+<path fill="none" stroke="black" d="M1830.71,-1728.05C1813.27,-1718.67 1791.4,-1706.89 1772.7,-1696.82"/>
+<polygon fill="black" stroke="black" points="1774.27,-1693.69 1763.81,-1692.03 1770.95,-1699.85 1774.27,-1693.69"/>
+</g>
+<!-- 14 -->
+<g id="node15" class="node">
+<title>14</title>
+<ellipse fill="none" stroke="black" cx="1360.18" cy="-1458" rx="242.36" ry="18"/>
+<text text-anchor="middle" x="1360.18" y="-1454.3" font-family="Times,serif" font-size="14.00">output.dense.weight: Tensor[(768, 3072), float32]</text>
+</g>
+<!-- 132 -->
+<g id="node100" class="node">
+<title>132</title>
+<polygon fill="none" stroke="black" points="1456.68,-1404 1263.68,-1404 1263.68,-1368 1456.68,-1368 1456.68,-1404"/>
+<text text-anchor="middle" x="1360.18" y="-1382.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 14&#45;&gt;132 -->
+<g id="edge100" class="edge">
+<title>14&#45;&gt;132</title>
+<path fill="none" stroke="black" d="M1360.18,-1439.7C1360.18,-1431.98 1360.18,-1422.71 1360.18,-1414.11"/>
+<polygon fill="black" stroke="black" points="1363.68,-1414.1 1360.18,-1404.1 1356.68,-1414.1 1363.68,-1414.1"/>
+</g>
+<!-- 15 -->
+<g id="node16" class="node">
+<title>15</title>
+<ellipse fill="none" stroke="black" cx="2034.18" cy="-1098" rx="203.36" ry="18"/>
+<text text-anchor="middle" x="2034.18" y="-1094.3" font-family="Times,serif" font-size="14.00">output.dense.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 138 -->
+<g id="node105" class="node">
+<title>138</title>
+<polygon fill="none" stroke="black" points="1626.68,-1044 1555.68,-1044 1555.68,-1008 1626.68,-1008 1626.68,-1044"/>
+<text text-anchor="middle" x="1591.18" y="-1022.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 15&#45;&gt;138 -->
+<g id="edge107" class="edge">
+<title>15&#45;&gt;138</title>
+<path fill="none" stroke="black" d="M1939.41,-1082.03C1846.14,-1067.29 1707.69,-1045.41 1636.73,-1034.2"/>
+<polygon fill="black" stroke="black" points="1637.19,-1030.73 1626.76,-1032.62 1636.09,-1037.64 1637.19,-1030.73"/>
+</g>
+<!-- 16 -->
+<g id="node17" class="node">
+<title>16</title>
+<ellipse fill="none" stroke="black" cx="766.18" cy="-522" rx="241.26" ry="18"/>
+<text text-anchor="middle" x="766.18" y="-518.3" font-family="Times,serif" font-size="14.00">output.LayerNorm.weight: Tensor[(768,), float32]</text>
+</g>
+<!-- 151 -->
+<g id="node116" class="node">
+<title>151</title>
+<polygon fill="none" stroke="black" points="970.18,-468 866.18,-468 866.18,-432 970.18,-432 970.18,-468"/>
+<text text-anchor="middle" x="918.18" y="-446.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 16&#45;&gt;151 -->
+<g id="edge124" class="edge">
+<title>16&#45;&gt;151</title>
+<path fill="none" stroke="black" d="M802.97,-504.05C823.73,-494.5 849.86,-482.46 871.98,-472.27"/>
+<polygon fill="black" stroke="black" points="873.57,-475.39 881.19,-468.03 870.65,-469.04 873.57,-475.39"/>
+</g>
+<!-- 17 -->
+<g id="node18" class="node">
+<title>17</title>
+<ellipse fill="none" stroke="black" cx="619.18" cy="-450" rx="228.56" ry="18"/>
+<text text-anchor="middle" x="619.18" y="-446.3" font-family="Times,serif" font-size="14.00">output.LayerNorm.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 152 -->
+<g id="node117" class="node">
+<title>152</title>
+<polygon fill="none" stroke="black" points="803.68,-396 732.68,-396 732.68,-360 803.68,-360 803.68,-396"/>
+<text text-anchor="middle" x="768.18" y="-374.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 17&#45;&gt;152 -->
+<g id="edge126" class="edge">
+<title>17&#45;&gt;152</title>
+<path fill="none" stroke="black" d="M655.25,-432.05C675.74,-422.42 701.59,-410.28 723.38,-400.05"/>
+<polygon fill="black" stroke="black" points="724.88,-403.21 732.44,-395.79 721.9,-396.87 724.88,-403.21"/>
+</g>
+<!-- 18 -->
+<g id="node19" class="node">
+<title>18</title>
+<ellipse fill="none" stroke="black" cx="482.18" cy="-234" rx="183.87" ry="18"/>
+<text text-anchor="middle" x="482.18" y="-230.3" font-family="Times,serif" font-size="14.00">gr:out:0: Tensor[(1, 14, 768), float32]</text>
+</g>
+<!-- 155 -->
+<g id="node120" class="node">
+<title>155</title>
+<polygon fill="none" stroke="black" points="677.18,-180 573.18,-180 573.18,-144 677.18,-144 677.18,-180"/>
+<text text-anchor="middle" x="625.18" y="-158.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 18&#45;&gt;155 -->
+<g id="edge130" class="edge">
+<title>18&#45;&gt;155</title>
+<path fill="none" stroke="black" d="M516.43,-216.23C535.77,-206.76 560.17,-194.82 580.95,-184.65"/>
+<polygon fill="black" stroke="black" points="582.77,-187.66 590.21,-180.12 579.69,-181.37 582.77,-187.66"/>
+</g>
+<!-- 19 -->
+<g id="node20" class="node">
+<title>19</title>
+<ellipse fill="none" stroke="black" cx="1522.18" cy="-3474" rx="204.16" ry="18"/>
+<text text-anchor="middle" x="1522.18" y="-3470.3" font-family="Times,serif" font-size="14.00">dropout:0: Tensor[(1, 12, 14, 14), float32]</text>
+</g>
+<!-- 67 -->
+<g id="node49" class="node">
+<title>67</title>
+<polygon fill="none" stroke="black" points="1610.18,-3420 1434.18,-3420 1434.18,-3384 1610.18,-3384 1610.18,-3420"/>
+<text text-anchor="middle" x="1522.18" y="-3398.3" font-family="Times,serif" font-size="14.00">multiply(·, 1.1111112)</text>
+</g>
+<!-- 19&#45;&gt;67 -->
+<g id="edge33" class="edge">
+<title>19&#45;&gt;67</title>
+<path fill="none" stroke="black" d="M1522.18,-3455.7C1522.18,-3447.98 1522.18,-3438.71 1522.18,-3430.11"/>
+<polygon fill="black" stroke="black" points="1525.68,-3430.1 1522.18,-3420.1 1518.68,-3430.1 1525.68,-3430.1"/>
+</g>
+<!-- 20 -->
+<g id="node21" class="node">
+<title>20</title>
+<ellipse fill="none" stroke="black" cx="619.18" cy="-2682" rx="192.27" ry="18"/>
+<text text-anchor="middle" x="619.18" y="-2678.3" font-family="Times,serif" font-size="14.00">dropout:1: Tensor[(1, 14, 768), float32]</text>
+</g>
+<!-- 96 -->
+<g id="node75" class="node">
+<title>96</title>
+<polygon fill="none" stroke="black" points="707.18,-2628 531.18,-2628 531.18,-2592 707.18,-2592 707.18,-2628"/>
+<text text-anchor="middle" x="619.18" y="-2606.3" font-family="Times,serif" font-size="14.00">multiply(·, 1.1111112)</text>
+</g>
+<!-- 20&#45;&gt;96 -->
+<g id="edge65" class="edge">
+<title>20&#45;&gt;96</title>
+<path fill="none" stroke="black" d="M619.18,-2663.7C619.18,-2655.98 619.18,-2646.71 619.18,-2638.11"/>
+<polygon fill="black" stroke="black" points="622.68,-2638.1 619.18,-2628.1 615.68,-2638.1 622.68,-2638.1"/>
+</g>
+<!-- 21 -->
+<g id="node22" class="node">
+<title>21</title>
+<ellipse fill="none" stroke="black" cx="1159.18" cy="-1098" rx="192.27" ry="18"/>
+<text text-anchor="middle" x="1159.18" y="-1094.3" font-family="Times,serif" font-size="14.00">dropout:2: Tensor[(1, 14, 768), float32]</text>
+</g>
+<!-- 140 -->
+<g id="node106" class="node">
+<title>140</title>
+<polygon fill="none" stroke="black" points="1247.18,-1044 1071.18,-1044 1071.18,-1008 1247.18,-1008 1247.18,-1044"/>
+<text text-anchor="middle" x="1159.18" y="-1022.3" font-family="Times,serif" font-size="14.00">multiply(·, 1.1111112)</text>
+</g>
+<!-- 21&#45;&gt;140 -->
+<g id="edge108" class="edge">
+<title>21&#45;&gt;140</title>
+<path fill="none" stroke="black" d="M1159.18,-1079.7C1159.18,-1071.98 1159.18,-1062.71 1159.18,-1054.11"/>
+<polygon fill="black" stroke="black" points="1162.68,-1054.1 1159.18,-1044.1 1155.68,-1054.1 1162.68,-1054.1"/>
+</g>
+<!-- 38 -->
+<g id="node27" class="node">
+<title>38</title>
+<polygon fill="none" stroke="black" points="2150.68,-4212 1981.68,-4212 1981.68,-4176 2150.68,-4176 2150.68,-4212"/>
+<text text-anchor="middle" x="2066.18" y="-4190.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 33&#45;&gt;38 -->
+<g id="edge5" class="edge">
+<title>33&#45;&gt;38</title>
+<path fill="none" stroke="black" d="M1450,-4247.99C1615.52,-4233.75 1844.41,-4214.07 1971.5,-4203.14"/>
+<polygon fill="black" stroke="black" points="1971.93,-4206.62 1981.59,-4202.27 1971.33,-4199.64 1971.93,-4206.62"/>
+</g>
+<!-- 36 -->
+<g id="node25" class="node">
+<title>36</title>
+<polygon fill="none" stroke="black" points="2290.68,-4356 1829.68,-4356 1829.68,-4320 2290.68,-4320 2290.68,-4356"/>
+<text text-anchor="middle" x="2060.18" y="-4334.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 768 768]| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 34&#45;&gt;36 -->
+<g id="edge3" class="edge">
+<title>34&#45;&gt;36</title>
+<path fill="none" stroke="black" d="M2054.91,-4391.7C2055.68,-4383.98 2056.61,-4374.71 2057.47,-4366.11"/>
+<polygon fill="black" stroke="black" points="2060.95,-4366.4 2058.47,-4356.1 2053.99,-4365.71 2060.95,-4366.4"/>
+</g>
+<!-- 37 -->
+<g id="node26" class="node">
+<title>37</title>
+<polygon fill="none" stroke="black" points="2171.68,-4284 1960.68,-4284 1960.68,-4248 2171.68,-4248 2171.68,-4284"/>
+<text text-anchor="middle" x="2066.18" y="-4262.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 36&#45;&gt;37 -->
+<g id="edge4" class="edge">
+<title>36&#45;&gt;37</title>
+<path fill="none" stroke="black" d="M2061.66,-4319.7C2062.32,-4311.98 2063.12,-4302.71 2063.85,-4294.11"/>
+<polygon fill="black" stroke="black" points="2067.34,-4294.37 2064.71,-4284.1 2060.37,-4293.77 2067.34,-4294.37"/>
+</g>
+<!-- 37&#45;&gt;38 -->
+<g id="edge6" class="edge">
+<title>37&#45;&gt;38</title>
+<path fill="none" stroke="black" d="M2066.18,-4247.7C2066.18,-4239.98 2066.18,-4230.71 2066.18,-4222.11"/>
+<polygon fill="black" stroke="black" points="2069.68,-4222.1 2066.18,-4212.1 2062.68,-4222.1 2069.68,-4222.1"/>
+</g>
+<!-- 40 -->
+<g id="node28" class="node">
+<title>40</title>
+<polygon fill="none" stroke="black" points="2296.68,-4140 1853.68,-4140 1853.68,-4104 2296.68,-4104 2296.68,-4140"/>
+<text text-anchor="middle" x="2075.18" y="-4118.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;1 &#160;14 768]| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 38&#45;&gt;40 -->
+<g id="edge7" class="edge">
+<title>38&#45;&gt;40</title>
+<path fill="none" stroke="black" d="M2068.4,-4175.7C2069.39,-4167.98 2070.58,-4158.71 2071.69,-4150.11"/>
+<polygon fill="black" stroke="black" points="2075.17,-4150.47 2072.98,-4140.1 2068.23,-4149.58 2075.17,-4150.47"/>
+</g>
+<!-- 40&#45;&gt;41 -->
+<g id="edge8" class="edge">
+<title>40&#45;&gt;41</title>
+<path fill="none" stroke="black" d="M2074.68,-4103.7C2074.46,-4095.98 2074.2,-4086.71 2073.95,-4078.11"/>
+<polygon fill="black" stroke="black" points="2077.45,-4078 2073.67,-4068.1 2070.45,-4078.2 2077.45,-4078"/>
+</g>
+<!-- 43 -->
+<g id="node30" class="node">
+<title>43</title>
+<polygon fill="none" stroke="black" points="2262.68,-3996 1883.68,-3996 1883.68,-3960 2262.68,-3960 2262.68,-3996"/>
+<text text-anchor="middle" x="2073.18" y="-3974.3" font-family="Times,serif" font-size="14.00">reshape(·, [ 1 14 12 64]| newshape=..., reverse=0)</text>
+</g>
+<!-- 41&#45;&gt;43 -->
+<g id="edge10" class="edge">
+<title>41&#45;&gt;43</title>
+<path fill="none" stroke="black" d="M2073.18,-4031.7C2073.18,-4023.98 2073.18,-4014.71 2073.18,-4006.11"/>
+<polygon fill="black" stroke="black" points="2076.68,-4006.1 2073.18,-3996.1 2069.68,-4006.1 2076.68,-4006.1"/>
+</g>
+<!-- 44 -->
+<g id="node31" class="node">
+<title>44</title>
+<polygon fill="none" stroke="black" points="2187.68,-3924 1958.68,-3924 1958.68,-3888 2187.68,-3888 2187.68,-3924"/>
+<text text-anchor="middle" x="2073.18" y="-3902.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 43&#45;&gt;44 -->
+<g id="edge11" class="edge">
+<title>43&#45;&gt;44</title>
+<path fill="none" stroke="black" d="M2073.18,-3959.7C2073.18,-3951.98 2073.18,-3942.71 2073.18,-3934.11"/>
+<polygon fill="black" stroke="black" points="2076.68,-3934.1 2073.18,-3924.1 2069.68,-3934.1 2076.68,-3934.1"/>
+</g>
+<!-- 46 -->
+<g id="node32" class="node">
+<title>46</title>
+<polygon fill="none" stroke="black" points="2283.68,-3780 1862.68,-3780 1862.68,-3744 2283.68,-3744 2283.68,-3780"/>
+<text text-anchor="middle" x="2073.18" y="-3758.3" font-family="Times,serif" font-size="14.00">reshape(·, [&#45;1 14 64]| newshape=[&#45;1, 14, 64], reverse=0)</text>
+</g>
+<!-- 44&#45;&gt;46 -->
+<g id="edge12" class="edge">
+<title>44&#45;&gt;46</title>
+<path fill="none" stroke="black" d="M2073.18,-3887.87C2073.18,-3863.67 2073.18,-3819.21 2073.18,-3790.39"/>
+<polygon fill="black" stroke="black" points="2076.68,-3790.19 2073.18,-3780.19 2069.68,-3790.19 2076.68,-3790.19"/>
+</g>
+<!-- 59 -->
+<g id="node44" class="node">
+<title>59</title>
+<polygon fill="none" stroke="black" points="2157.68,-3708 1988.68,-3708 1988.68,-3672 2157.68,-3672 2157.68,-3708"/>
+<text text-anchor="middle" x="2073.18" y="-3686.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 46&#45;&gt;59 -->
+<g id="edge26" class="edge">
+<title>46&#45;&gt;59</title>
+<path fill="none" stroke="black" d="M2073.18,-3743.7C2073.18,-3735.98 2073.18,-3726.71 2073.18,-3718.11"/>
+<polygon fill="black" stroke="black" points="2076.68,-3718.1 2073.18,-3708.1 2069.68,-3718.1 2076.68,-3718.1"/>
+</g>
+<!-- 51 -->
+<g id="node37" class="node">
+<title>51</title>
+<polygon fill="none" stroke="black" points="2620.68,-4212 2451.68,-4212 2451.68,-4176 2620.68,-4176 2620.68,-4212"/>
+<text text-anchor="middle" x="2536.18" y="-4190.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 47&#45;&gt;51 -->
+<g id="edge17" class="edge">
+<title>47&#45;&gt;51</title>
+<path fill="none" stroke="black" d="M1916.97,-4247.99C2083.42,-4233.74 2313.64,-4214.04 2441.3,-4203.12"/>
+<polygon fill="black" stroke="black" points="2441.77,-4206.59 2451.44,-4202.25 2441.18,-4199.62 2441.77,-4206.59"/>
+</g>
+<!-- 49 -->
+<g id="node35" class="node">
+<title>49</title>
+<polygon fill="none" stroke="black" points="2784.68,-4356 2323.68,-4356 2323.68,-4320 2784.68,-4320 2784.68,-4356"/>
+<text text-anchor="middle" x="2554.18" y="-4334.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 768 768]| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 48&#45;&gt;49 -->
+<g id="edge15" class="edge">
+<title>48&#45;&gt;49</title>
+<path fill="none" stroke="black" d="M2557.19,-4391.7C2556.75,-4383.98 2556.22,-4374.71 2555.73,-4366.11"/>
+<polygon fill="black" stroke="black" points="2559.22,-4365.89 2555.15,-4356.1 2552.23,-4366.29 2559.22,-4365.89"/>
+</g>
+<!-- 50 -->
+<g id="node36" class="node">
+<title>50</title>
+<polygon fill="none" stroke="black" points="2653.68,-4284 2442.68,-4284 2442.68,-4248 2653.68,-4248 2653.68,-4284"/>
+<text text-anchor="middle" x="2548.18" y="-4262.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 49&#45;&gt;50 -->
+<g id="edge16" class="edge">
+<title>49&#45;&gt;50</title>
+<path fill="none" stroke="black" d="M2552.69,-4319.7C2552.03,-4311.98 2551.24,-4302.71 2550.5,-4294.11"/>
+<polygon fill="black" stroke="black" points="2553.98,-4293.77 2549.64,-4284.1 2547.01,-4294.37 2553.98,-4293.77"/>
+</g>
+<!-- 50&#45;&gt;51 -->
+<g id="edge18" class="edge">
+<title>50&#45;&gt;51</title>
+<path fill="none" stroke="black" d="M2545.21,-4247.7C2543.89,-4239.98 2542.3,-4230.71 2540.82,-4222.11"/>
+<polygon fill="black" stroke="black" points="2544.25,-4221.37 2539.11,-4212.1 2537.35,-4222.55 2544.25,-4221.37"/>
+</g>
+<!-- 52 -->
+<g id="node38" class="node">
+<title>52</title>
+<polygon fill="none" stroke="black" points="2757.68,-4140 2314.68,-4140 2314.68,-4104 2757.68,-4104 2757.68,-4140"/>
+<text text-anchor="middle" x="2536.18" y="-4118.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;1 &#160;14 768]| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 51&#45;&gt;52 -->
+<g id="edge19" class="edge">
+<title>51&#45;&gt;52</title>
+<path fill="none" stroke="black" d="M2536.18,-4175.7C2536.18,-4167.98 2536.18,-4158.71 2536.18,-4150.11"/>
+<polygon fill="black" stroke="black" points="2539.68,-4150.1 2536.18,-4140.1 2532.68,-4150.1 2539.68,-4150.1"/>
+</g>
+<!-- 52&#45;&gt;53 -->
+<g id="edge20" class="edge">
+<title>52&#45;&gt;53</title>
+<path fill="none" stroke="black" d="M2536.18,-4103.7C2536.18,-4095.98 2536.18,-4086.71 2536.18,-4078.11"/>
+<polygon fill="black" stroke="black" points="2539.68,-4078.1 2536.18,-4068.1 2532.68,-4078.1 2539.68,-4078.1"/>
+</g>
+<!-- 54 -->
+<g id="node40" class="node">
+<title>54</title>
+<polygon fill="none" stroke="black" points="2717.68,-3996 2338.68,-3996 2338.68,-3960 2717.68,-3960 2717.68,-3996"/>
+<text text-anchor="middle" x="2528.18" y="-3974.3" font-family="Times,serif" font-size="14.00">reshape(·, [ 1 14 12 64]| newshape=..., reverse=0)</text>
+</g>
+<!-- 53&#45;&gt;54 -->
+<g id="edge22" class="edge">
+<title>53&#45;&gt;54</title>
+<path fill="none" stroke="black" d="M2534.2,-4031.7C2533.32,-4023.98 2532.26,-4014.71 2531.28,-4006.11"/>
+<polygon fill="black" stroke="black" points="2534.74,-4005.64 2530.13,-3996.1 2527.79,-4006.44 2534.74,-4005.64"/>
+</g>
+<!-- 55 -->
+<g id="node41" class="node">
+<title>55</title>
+<polygon fill="none" stroke="black" points="2633.68,-3924 2404.68,-3924 2404.68,-3888 2633.68,-3888 2633.68,-3924"/>
+<text text-anchor="middle" x="2519.18" y="-3902.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 3, 1])</text>
+</g>
+<!-- 54&#45;&gt;55 -->
+<g id="edge23" class="edge">
+<title>54&#45;&gt;55</title>
+<path fill="none" stroke="black" d="M2525.95,-3959.7C2524.96,-3951.98 2523.77,-3942.71 2522.66,-3934.11"/>
+<polygon fill="black" stroke="black" points="2526.12,-3933.58 2521.38,-3924.1 2519.18,-3934.47 2526.12,-3933.58"/>
+</g>
+<!-- 57 -->
+<g id="node42" class="node">
+<title>57</title>
+<polygon fill="none" stroke="black" points="2713.68,-3852 2292.68,-3852 2292.68,-3816 2713.68,-3816 2713.68,-3852"/>
+<text text-anchor="middle" x="2503.18" y="-3830.3" font-family="Times,serif" font-size="14.00">reshape(·, [&#45;1 64 14]| newshape=[&#45;1, 64, 14], reverse=0)</text>
+</g>
+<!-- 55&#45;&gt;57 -->
+<g id="edge24" class="edge">
+<title>55&#45;&gt;57</title>
+<path fill="none" stroke="black" d="M2515.22,-3887.7C2513.46,-3879.98 2511.34,-3870.71 2509.37,-3862.11"/>
+<polygon fill="black" stroke="black" points="2512.73,-3861.07 2507.09,-3852.1 2505.9,-3862.63 2512.73,-3861.07"/>
+</g>
+<!-- 58 -->
+<g id="node43" class="node">
+<title>58</title>
+<polygon fill="none" stroke="black" points="2544.68,-3780 2333.68,-3780 2333.68,-3744 2544.68,-3744 2544.68,-3780"/>
+<text text-anchor="middle" x="2439.18" y="-3758.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 57&#45;&gt;58 -->
+<g id="edge25" class="edge">
+<title>57&#45;&gt;58</title>
+<path fill="none" stroke="black" d="M2487.36,-3815.7C2479.61,-3807.22 2470.14,-3796.86 2461.65,-3787.58"/>
+<polygon fill="black" stroke="black" points="2464.15,-3785.12 2454.81,-3780.1 2458.98,-3789.85 2464.15,-3785.12"/>
+</g>
+<!-- 58&#45;&gt;59 -->
+<g id="edge27" class="edge">
+<title>58&#45;&gt;59</title>
+<path fill="none" stroke="black" d="M2350.11,-3743.97C2294.92,-3733.41 2224.09,-3719.86 2167.87,-3709.11"/>
+<polygon fill="black" stroke="black" points="2168.39,-3705.65 2157.91,-3707.21 2167.08,-3712.52 2168.39,-3705.65"/>
+</g>
+<!-- 61 -->
+<g id="node45" class="node">
+<title>61</title>
+<polygon fill="none" stroke="black" points="2262.68,-3636 1883.68,-3636 1883.68,-3600 2262.68,-3600 2262.68,-3636"/>
+<text text-anchor="middle" x="2073.18" y="-3614.3" font-family="Times,serif" font-size="14.00">reshape(·, [ 1 12 14 14]| newshape=..., reverse=0)</text>
+</g>
+<!-- 59&#45;&gt;61 -->
+<g id="edge28" class="edge">
+<title>59&#45;&gt;61</title>
+<path fill="none" stroke="black" d="M2073.18,-3671.7C2073.18,-3663.98 2073.18,-3654.71 2073.18,-3646.11"/>
+<polygon fill="black" stroke="black" points="2076.68,-3646.1 2073.18,-3636.1 2069.68,-3646.1 2076.68,-3646.1"/>
+</g>
+<!-- 63 -->
+<g id="node46" class="node">
+<title>63</title>
+<polygon fill="none" stroke="black" points="2126.68,-3564 2019.68,-3564 2019.68,-3528 2126.68,-3528 2126.68,-3564"/>
+<text text-anchor="middle" x="2073.18" y="-3542.3" font-family="Times,serif" font-size="14.00">divide(·, 8.0)</text>
+</g>
+<!-- 61&#45;&gt;63 -->
+<g id="edge29" class="edge">
+<title>61&#45;&gt;63</title>
+<path fill="none" stroke="black" d="M2073.18,-3599.7C2073.18,-3591.98 2073.18,-3582.71 2073.18,-3574.11"/>
+<polygon fill="black" stroke="black" points="2076.68,-3574.1 2073.18,-3564.1 2069.68,-3574.1 2076.68,-3574.1"/>
+</g>
+<!-- 63&#45;&gt;64 -->
+<g id="edge30" class="edge">
+<title>63&#45;&gt;64</title>
+<path fill="none" stroke="black" d="M2019.59,-3530.41C2016.41,-3529.58 2013.25,-3528.77 2010.18,-3528 1947.6,-3512.26 1874.92,-3495.57 1828.99,-3485.21"/>
+<polygon fill="black" stroke="black" points="1829.57,-3481.76 1819.05,-3482.98 1828.03,-3488.59 1829.57,-3481.76"/>
+</g>
+<!-- 65 -->
+<g id="node48" class="node">
+<title>65</title>
+<polygon fill="none" stroke="black" points="1834.18,-3420 1660.18,-3420 1660.18,-3384 1834.18,-3384 1834.18,-3420"/>
+<text text-anchor="middle" x="1747.18" y="-3398.3" font-family="Times,serif" font-size="14.00">nn.softmax(·| axis=&#45;1)</text>
+</g>
+<!-- 64&#45;&gt;65 -->
+<g id="edge32" class="edge">
+<title>64&#45;&gt;65</title>
+<path fill="none" stroke="black" d="M1774.28,-3455.7C1770.18,-3447.73 1765.23,-3438.1 1760.68,-3429.26"/>
+<polygon fill="black" stroke="black" points="1763.66,-3427.4 1755.97,-3420.1 1757.43,-3430.6 1763.66,-3427.4"/>
+</g>
+<!-- 68 -->
+<g id="node50" class="node">
+<title>68</title>
+<polygon fill="none" stroke="black" points="1574.18,-3348 1470.18,-3348 1470.18,-3312 1574.18,-3312 1574.18,-3348"/>
+<text text-anchor="middle" x="1522.18" y="-3326.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 65&#45;&gt;68 -->
+<g id="edge34" class="edge">
+<title>65&#45;&gt;68</title>
+<path fill="none" stroke="black" d="M1692.42,-3383.97C1659.75,-3373.8 1618.15,-3360.86 1584.26,-3350.31"/>
+<polygon fill="black" stroke="black" points="1585.05,-3346.9 1574.46,-3347.27 1582.97,-3353.58 1585.05,-3346.9"/>
+</g>
+<!-- 67&#45;&gt;68 -->
+<g id="edge35" class="edge">
+<title>67&#45;&gt;68</title>
+<path fill="none" stroke="black" d="M1522.18,-3383.7C1522.18,-3375.98 1522.18,-3366.71 1522.18,-3358.11"/>
+<polygon fill="black" stroke="black" points="1525.68,-3358.1 1522.18,-3348.1 1518.68,-3358.1 1525.68,-3358.1"/>
+</g>
+<!-- 70 -->
+<g id="node51" class="node">
+<title>70</title>
+<polygon fill="none" stroke="black" points="1732.68,-3276 1311.68,-3276 1311.68,-3240 1732.68,-3240 1732.68,-3276"/>
+<text text-anchor="middle" x="1522.18" y="-3254.3" font-family="Times,serif" font-size="14.00">reshape(·, [&#45;1 14 14]| newshape=[&#45;1, 14, 14], reverse=0)</text>
+</g>
+<!-- 68&#45;&gt;70 -->
+<g id="edge36" class="edge">
+<title>68&#45;&gt;70</title>
+<path fill="none" stroke="black" d="M1522.18,-3311.7C1522.18,-3303.98 1522.18,-3294.71 1522.18,-3286.11"/>
+<polygon fill="black" stroke="black" points="1525.68,-3286.1 1522.18,-3276.1 1518.68,-3286.1 1525.68,-3286.1"/>
+</g>
+<!-- 82 -->
+<g id="node63" class="node">
+<title>82</title>
+<polygon fill="none" stroke="black" points="1417.68,-3204 1248.68,-3204 1248.68,-3168 1417.68,-3168 1417.68,-3204"/>
+<text text-anchor="middle" x="1333.18" y="-3182.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 70&#45;&gt;82 -->
+<g id="edge50" class="edge">
+<title>70&#45;&gt;82</title>
+<path fill="none" stroke="black" d="M1475.94,-3239.88C1449.65,-3230.14 1416.53,-3217.87 1388.82,-3207.61"/>
+<polygon fill="black" stroke="black" points="1389.85,-3204.26 1379.26,-3204.07 1387.42,-3210.82 1389.85,-3204.26"/>
+</g>
+<!-- 75 -->
+<g id="node56" class="node">
+<title>75</title>
+<polygon fill="none" stroke="black" points="870.68,-4212 701.68,-4212 701.68,-4176 870.68,-4176 870.68,-4212"/>
+<text text-anchor="middle" x="786.18" y="-4190.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 71&#45;&gt;75 -->
+<g id="edge41" class="edge">
+<title>71&#45;&gt;75</title>
+<path fill="none" stroke="black" d="M786.18,-4247.7C786.18,-4239.98 786.18,-4230.71 786.18,-4222.11"/>
+<polygon fill="black" stroke="black" points="789.68,-4222.1 786.18,-4212.1 782.68,-4222.1 789.68,-4222.1"/>
+</g>
+<!-- 73 -->
+<g id="node54" class="node">
+<title>73</title>
+<polygon fill="none" stroke="black" points="495.68,-4356 34.68,-4356 34.68,-4320 495.68,-4320 495.68,-4356"/>
+<text text-anchor="middle" x="265.18" y="-4334.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 768 768]| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 72&#45;&gt;73 -->
+<g id="edge39" class="edge">
+<title>72&#45;&gt;73</title>
+<path fill="none" stroke="black" d="M265.18,-4391.7C265.18,-4383.98 265.18,-4374.71 265.18,-4366.11"/>
+<polygon fill="black" stroke="black" points="268.68,-4366.1 265.18,-4356.1 261.68,-4366.1 268.68,-4366.1"/>
+</g>
+<!-- 74 -->
+<g id="node55" class="node">
+<title>74</title>
+<polygon fill="none" stroke="black" points="370.68,-4284 159.68,-4284 159.68,-4248 370.68,-4248 370.68,-4284"/>
+<text text-anchor="middle" x="265.18" y="-4262.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 73&#45;&gt;74 -->
+<g id="edge40" class="edge">
+<title>73&#45;&gt;74</title>
+<path fill="none" stroke="black" d="M265.18,-4319.7C265.18,-4311.98 265.18,-4302.71 265.18,-4294.11"/>
+<polygon fill="black" stroke="black" points="268.68,-4294.1 265.18,-4284.1 261.68,-4294.1 268.68,-4294.1"/>
+</g>
+<!-- 74&#45;&gt;75 -->
+<g id="edge42" class="edge">
+<title>74&#45;&gt;75</title>
+<path fill="none" stroke="black" d="M370.92,-4250.79C464.96,-4238.16 601.19,-4219.85 691.66,-4207.7"/>
+<polygon fill="black" stroke="black" points="692.19,-4211.16 701.63,-4206.36 691.25,-4204.22 692.19,-4211.16"/>
+</g>
+<!-- 76 -->
+<g id="node57" class="node">
+<title>76</title>
+<polygon fill="none" stroke="black" points="1352.68,-4140 909.68,-4140 909.68,-4104 1352.68,-4104 1352.68,-4140"/>
+<text text-anchor="middle" x="1131.18" y="-4118.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;1 &#160;14 768]| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 75&#45;&gt;76 -->
+<g id="edge43" class="edge">
+<title>75&#45;&gt;76</title>
+<path fill="none" stroke="black" d="M870.13,-4175.97C920.63,-4165.72 985.02,-4152.66 1037.21,-4142.07"/>
+<polygon fill="black" stroke="black" points="1038.04,-4145.47 1047.14,-4140.05 1036.65,-4138.61 1038.04,-4145.47"/>
+</g>
+<!-- 76&#45;&gt;77 -->
+<g id="edge44" class="edge">
+<title>76&#45;&gt;77</title>
+<path fill="none" stroke="black" d="M1131.18,-4103.7C1131.18,-4095.98 1131.18,-4086.71 1131.18,-4078.11"/>
+<polygon fill="black" stroke="black" points="1134.68,-4078.1 1131.18,-4068.1 1127.68,-4078.1 1134.68,-4078.1"/>
+</g>
+<!-- 78 -->
+<g id="node59" class="node">
+<title>78</title>
+<polygon fill="none" stroke="black" points="1338.68,-3996 959.68,-3996 959.68,-3960 1338.68,-3960 1338.68,-3996"/>
+<text text-anchor="middle" x="1149.18" y="-3974.3" font-family="Times,serif" font-size="14.00">reshape(·, [ 1 14 12 64]| newshape=..., reverse=0)</text>
+</g>
+<!-- 77&#45;&gt;78 -->
+<g id="edge46" class="edge">
+<title>77&#45;&gt;78</title>
+<path fill="none" stroke="black" d="M1135.63,-4031.7C1137.63,-4023.9 1140.05,-4014.51 1142.28,-4005.83"/>
+<polygon fill="black" stroke="black" points="1145.68,-4006.66 1144.78,-3996.1 1138.9,-4004.92 1145.68,-4006.66"/>
+</g>
+<!-- 79 -->
+<g id="node60" class="node">
+<title>79</title>
+<polygon fill="none" stroke="black" points="1265.68,-3924 1036.68,-3924 1036.68,-3888 1265.68,-3888 1265.68,-3924"/>
+<text text-anchor="middle" x="1151.18" y="-3902.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 78&#45;&gt;79 -->
+<g id="edge47" class="edge">
+<title>78&#45;&gt;79</title>
+<path fill="none" stroke="black" d="M1149.67,-3959.7C1149.89,-3951.98 1150.16,-3942.71 1150.4,-3934.11"/>
+<polygon fill="black" stroke="black" points="1153.9,-3934.2 1150.69,-3924.1 1146.9,-3934 1153.9,-3934.2"/>
+</g>
+<!-- 80 -->
+<g id="node61" class="node">
+<title>80</title>
+<polygon fill="none" stroke="black" points="1364.68,-3852 943.68,-3852 943.68,-3816 1364.68,-3816 1364.68,-3852"/>
+<text text-anchor="middle" x="1154.18" y="-3830.3" font-family="Times,serif" font-size="14.00">reshape(·, [&#45;1 14 64]| newshape=[&#45;1, 14, 64], reverse=0)</text>
+</g>
+<!-- 79&#45;&gt;80 -->
+<g id="edge48" class="edge">
+<title>79&#45;&gt;80</title>
+<path fill="none" stroke="black" d="M1151.92,-3887.7C1152.25,-3879.98 1152.65,-3870.71 1153.01,-3862.11"/>
+<polygon fill="black" stroke="black" points="1156.51,-3862.25 1153.44,-3852.1 1149.52,-3861.95 1156.51,-3862.25"/>
+</g>
+<!-- 81 -->
+<g id="node62" class="node">
+<title>81</title>
+<polygon fill="none" stroke="black" points="1300.68,-3708 1089.68,-3708 1089.68,-3672 1300.68,-3672 1300.68,-3708"/>
+<text text-anchor="middle" x="1195.18" y="-3686.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 80&#45;&gt;81 -->
+<g id="edge49" class="edge">
+<title>80&#45;&gt;81</title>
+<path fill="none" stroke="black" d="M1159.12,-3815.87C1166.14,-3791.56 1179.06,-3746.82 1187.38,-3718.01"/>
+<polygon fill="black" stroke="black" points="1190.8,-3718.77 1190.21,-3708.19 1184.08,-3716.83 1190.8,-3718.77"/>
+</g>
+<!-- 81&#45;&gt;82 -->
+<g id="edge51" class="edge">
+<title>81&#45;&gt;82</title>
+<path fill="none" stroke="black" d="M1196.95,-3671.97C1199.58,-3645.34 1204.18,-3592.21 1204.18,-3547 1204.18,-3547 1204.18,-3547 1204.18,-3329 1204.18,-3275.81 1254.38,-3233.72 1292.15,-3209.58"/>
+<polygon fill="black" stroke="black" points="1294.3,-3212.37 1300.96,-3204.13 1290.61,-3206.42 1294.3,-3212.37"/>
+</g>
+<!-- 84 -->
+<g id="node64" class="node">
+<title>84</title>
+<polygon fill="none" stroke="black" points="1522.68,-3132 1143.68,-3132 1143.68,-3096 1522.68,-3096 1522.68,-3132"/>
+<text text-anchor="middle" x="1333.18" y="-3110.3" font-family="Times,serif" font-size="14.00">reshape(·, [ 1 12 14 64]| newshape=..., reverse=0)</text>
+</g>
+<!-- 82&#45;&gt;84 -->
+<g id="edge52" class="edge">
+<title>82&#45;&gt;84</title>
+<path fill="none" stroke="black" d="M1333.18,-3167.7C1333.18,-3159.98 1333.18,-3150.71 1333.18,-3142.11"/>
+<polygon fill="black" stroke="black" points="1336.68,-3142.1 1333.18,-3132.1 1329.68,-3142.1 1336.68,-3142.1"/>
+</g>
+<!-- 85 -->
+<g id="node65" class="node">
+<title>85</title>
+<polygon fill="none" stroke="black" points="1447.68,-3060 1218.68,-3060 1218.68,-3024 1447.68,-3024 1447.68,-3060"/>
+<text text-anchor="middle" x="1333.18" y="-3038.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 84&#45;&gt;85 -->
+<g id="edge53" class="edge">
+<title>84&#45;&gt;85</title>
+<path fill="none" stroke="black" d="M1333.18,-3095.7C1333.18,-3087.98 1333.18,-3078.71 1333.18,-3070.11"/>
+<polygon fill="black" stroke="black" points="1336.68,-3070.1 1333.18,-3060.1 1329.68,-3070.1 1336.68,-3070.1"/>
+</g>
+<!-- 86 -->
+<g id="node66" class="node">
+<title>86</title>
+<polygon fill="none" stroke="black" points="1365.68,-2988 1300.68,-2988 1300.68,-2952 1365.68,-2952 1365.68,-2988"/>
+<text text-anchor="middle" x="1333.18" y="-2966.3" font-family="Times,serif" font-size="14.00">copy(·)</text>
+</g>
+<!-- 85&#45;&gt;86 -->
+<g id="edge54" class="edge">
+<title>85&#45;&gt;86</title>
+<path fill="none" stroke="black" d="M1333.18,-3023.7C1333.18,-3015.98 1333.18,-3006.71 1333.18,-2998.11"/>
+<polygon fill="black" stroke="black" points="1336.68,-2998.1 1333.18,-2988.1 1329.68,-2998.1 1336.68,-2998.1"/>
+</g>
+<!-- 87 -->
+<g id="node67" class="node">
+<title>87</title>
+<polygon fill="none" stroke="black" points="1554.68,-2916 1111.68,-2916 1111.68,-2880 1554.68,-2880 1554.68,-2916"/>
+<text text-anchor="middle" x="1333.18" y="-2894.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;1 &#160;14 768]| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 86&#45;&gt;87 -->
+<g id="edge55" class="edge">
+<title>86&#45;&gt;87</title>
+<path fill="none" stroke="black" d="M1333.18,-2951.7C1333.18,-2943.98 1333.18,-2934.71 1333.18,-2926.11"/>
+<polygon fill="black" stroke="black" points="1336.68,-2926.1 1333.18,-2916.1 1329.68,-2926.1 1336.68,-2926.1"/>
+</g>
+<!-- 88 -->
+<g id="node68" class="node">
+<title>88</title>
+<polygon fill="none" stroke="black" points="1526.18,-2844 1078.18,-2844 1078.18,-2808 1526.18,-2808 1526.18,-2844"/>
+<text text-anchor="middle" x="1302.18" y="-2822.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 &#160;14 768]| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 87&#45;&gt;88 -->
+<g id="edge56" class="edge">
+<title>87&#45;&gt;88</title>
+<path fill="none" stroke="black" d="M1325.51,-2879.7C1321.98,-2871.73 1317.72,-2862.1 1313.81,-2853.26"/>
+<polygon fill="black" stroke="black" points="1317,-2851.83 1309.75,-2844.1 1310.6,-2854.67 1317,-2851.83"/>
+</g>
+<!-- 92 -->
+<g id="node72" class="node">
+<title>92</title>
+<polygon fill="none" stroke="black" points="1135.68,-2772 966.68,-2772 966.68,-2736 1135.68,-2736 1135.68,-2772"/>
+<text text-anchor="middle" x="1051.18" y="-2750.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 88&#45;&gt;92 -->
+<g id="edge60" class="edge">
+<title>88&#45;&gt;92</title>
+<path fill="none" stroke="black" d="M1241.1,-2807.97C1205.26,-2797.97 1159.82,-2785.3 1122.37,-2774.85"/>
+<polygon fill="black" stroke="black" points="1123.1,-2771.43 1112.53,-2772.11 1121.22,-2778.17 1123.1,-2771.43"/>
+</g>
+<!-- 90 -->
+<g id="node70" class="node">
+<title>90</title>
+<polygon fill="none" stroke="black" points="1093.68,-2916 632.68,-2916 632.68,-2880 1093.68,-2880 1093.68,-2916"/>
+<text text-anchor="middle" x="863.18" y="-2894.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 768 768]| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 89&#45;&gt;90 -->
+<g id="edge58" class="edge">
+<title>89&#45;&gt;90</title>
+<path fill="none" stroke="black" d="M863.18,-2951.7C863.18,-2943.98 863.18,-2934.71 863.18,-2926.11"/>
+<polygon fill="black" stroke="black" points="866.68,-2926.1 863.18,-2916.1 859.68,-2926.1 866.68,-2926.1"/>
+</g>
+<!-- 91 -->
+<g id="node71" class="node">
+<title>91</title>
+<polygon fill="none" stroke="black" points="998.68,-2844 787.68,-2844 787.68,-2808 998.68,-2808 998.68,-2844"/>
+<text text-anchor="middle" x="893.18" y="-2822.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 90&#45;&gt;91 -->
+<g id="edge59" class="edge">
+<title>90&#45;&gt;91</title>
+<path fill="none" stroke="black" d="M870.59,-2879.7C873.97,-2871.81 878.05,-2862.3 881.8,-2853.55"/>
+<polygon fill="black" stroke="black" points="885.12,-2854.67 885.85,-2844.1 878.69,-2851.92 885.12,-2854.67"/>
+</g>
+<!-- 91&#45;&gt;92 -->
+<g id="edge61" class="edge">
+<title>91&#45;&gt;92</title>
+<path fill="none" stroke="black" d="M931.83,-2807.88C953.43,-2798.31 980.53,-2786.3 1003.44,-2776.15"/>
+<polygon fill="black" stroke="black" points="1004.93,-2779.32 1012.65,-2772.07 1002.09,-2772.92 1004.93,-2779.32"/>
+</g>
+<!-- 93 -->
+<g id="node73" class="node">
+<title>93</title>
+<polygon fill="none" stroke="black" points="1272.68,-2700 829.68,-2700 829.68,-2664 1272.68,-2664 1272.68,-2700"/>
+<text text-anchor="middle" x="1051.18" y="-2678.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;1 &#160;14 768]| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 92&#45;&gt;93 -->
+<g id="edge62" class="edge">
+<title>92&#45;&gt;93</title>
+<path fill="none" stroke="black" d="M1051.18,-2735.7C1051.18,-2727.98 1051.18,-2718.71 1051.18,-2710.11"/>
+<polygon fill="black" stroke="black" points="1054.68,-2710.1 1051.18,-2700.1 1047.68,-2710.1 1054.68,-2710.1"/>
+</g>
+<!-- 93&#45;&gt;94 -->
+<g id="edge63" class="edge">
+<title>93&#45;&gt;94</title>
+<path fill="none" stroke="black" d="M1051.18,-2663.7C1051.18,-2655.98 1051.18,-2646.71 1051.18,-2638.11"/>
+<polygon fill="black" stroke="black" points="1054.68,-2638.1 1051.18,-2628.1 1047.68,-2638.1 1054.68,-2638.1"/>
+</g>
+<!-- 97 -->
+<g id="node76" class="node">
+<title>97</title>
+<polygon fill="none" stroke="black" points="671.18,-2556 567.18,-2556 567.18,-2520 671.18,-2520 671.18,-2556"/>
+<text text-anchor="middle" x="619.18" y="-2534.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 94&#45;&gt;97 -->
+<g id="edge66" class="edge">
+<title>94&#45;&gt;97</title>
+<path fill="none" stroke="black" d="M1015.46,-2603.21C942,-2591.31 773.2,-2563.96 681.4,-2549.08"/>
+<polygon fill="black" stroke="black" points="681.84,-2545.61 671.4,-2547.46 680.72,-2552.52 681.84,-2545.61"/>
+</g>
+<!-- 96&#45;&gt;97 -->
+<g id="edge67" class="edge">
+<title>96&#45;&gt;97</title>
+<path fill="none" stroke="black" d="M619.18,-2591.7C619.18,-2583.98 619.18,-2574.71 619.18,-2566.11"/>
+<polygon fill="black" stroke="black" points="622.68,-2566.1 619.18,-2556.1 615.68,-2566.1 622.68,-2566.1"/>
+</g>
+<!-- 97&#45;&gt;98 -->
+<g id="edge68" class="edge">
+<title>97&#45;&gt;98</title>
+<path fill="none" stroke="black" d="M592.27,-2519.88C578.02,-2510.81 560.33,-2499.55 544.94,-2489.76"/>
+<polygon fill="black" stroke="black" points="546.5,-2486.61 536.19,-2484.19 542.75,-2492.51 546.5,-2486.61"/>
+</g>
+<!-- 100 -->
+<g id="node78" class="node">
+<title>100</title>
+<polygon fill="none" stroke="black" points="577.68,-2412 250.68,-2412 250.68,-2376 577.68,-2376 577.68,-2412"/>
+<text text-anchor="middle" x="414.18" y="-2390.3" font-family="Times,serif" font-size="14.00">mean(·| axis=[&#45;1], keepdims=1, exclude=0)</text>
+</g>
+<!-- 98&#45;&gt;100 -->
+<g id="edge70" class="edge">
+<title>98&#45;&gt;100</title>
+<path fill="none" stroke="black" d="M485.69,-2447.7C473.61,-2438.8 458.72,-2427.82 445.66,-2418.2"/>
+<polygon fill="black" stroke="black" points="447.52,-2415.22 437.39,-2412.1 443.36,-2420.85 447.52,-2415.22"/>
+</g>
+<!-- 101 -->
+<g id="node79" class="node">
+<title>101</title>
+<polygon fill="none" stroke="black" points="435.68,-2340 330.68,-2340 330.68,-2304 435.68,-2304 435.68,-2340"/>
+<text text-anchor="middle" x="383.18" y="-2318.3" font-family="Times,serif" font-size="14.00">subtract(·, ·)</text>
+</g>
+<!-- 98&#45;&gt;101 -->
+<g id="edge71" class="edge">
+<title>98&#45;&gt;101</title>
+<path fill="none" stroke="black" d="M473.5,-2463.36C407.45,-2459.5 270.19,-2447.32 241.18,-2412 231.02,-2399.64 232.05,-2389.14 241.18,-2376 259.1,-2350.19 291.2,-2336.92 320.35,-2330.12"/>
+<polygon fill="black" stroke="black" points="321.45,-2333.46 330.51,-2327.97 320,-2326.61 321.45,-2333.46"/>
+</g>
+<!-- 104 -->
+<g id="node80" class="node">
+<title>104</title>
+<polygon fill="none" stroke="black" points="959.68,-2412 632.68,-2412 632.68,-2376 959.68,-2376 959.68,-2412"/>
+<text text-anchor="middle" x="796.18" y="-2390.3" font-family="Times,serif" font-size="14.00">mean(·| axis=[&#45;1], keepdims=1, exclude=0)</text>
+</g>
+<!-- 98&#45;&gt;104 -->
+<g id="edge73" class="edge">
+<title>98&#45;&gt;104</title>
+<path fill="none" stroke="black" d="M544.94,-2456.28C587.61,-2445.87 660.23,-2428.16 716.45,-2414.45"/>
+<polygon fill="black" stroke="black" points="717.5,-2417.79 726.39,-2412.02 715.84,-2410.99 717.5,-2417.79"/>
+</g>
+<!-- 105 -->
+<g id="node81" class="node">
+<title>105</title>
+<polygon fill="none" stroke="black" points="816.18,-2340 454.18,-2340 454.18,-2304 816.18,-2304 816.18,-2340"/>
+<text text-anchor="middle" x="635.18" y="-2318.3" font-family="Times,serif" font-size="14.00">variance(·, ·| axis=[&#45;1], keepdims=1, exclude=0)</text>
+</g>
+<!-- 98&#45;&gt;105 -->
+<g id="edge74" class="edge">
+<title>98&#45;&gt;105</title>
+<path fill="none" stroke="black" d="M540.98,-2447.78C555.96,-2438.57 573.33,-2426.2 586.18,-2412 602.91,-2393.51 616.1,-2368.31 624.56,-2349.44"/>
+<polygon fill="black" stroke="black" points="627.79,-2350.78 628.55,-2340.22 621.37,-2348.01 627.79,-2350.78"/>
+</g>
+<!-- 100&#45;&gt;101 -->
+<g id="edge72" class="edge">
+<title>100&#45;&gt;101</title>
+<path fill="none" stroke="black" d="M406.51,-2375.7C402.98,-2367.73 398.72,-2358.1 394.81,-2349.26"/>
+<polygon fill="black" stroke="black" points="398,-2347.83 390.75,-2340.1 391.6,-2350.67 398,-2347.83"/>
+</g>
+<!-- 109 -->
+<g id="node84" class="node">
+<title>109</title>
+<polygon fill="none" stroke="black" points="553.68,-2124 464.68,-2124 464.68,-2088 553.68,-2088 553.68,-2124"/>
+<text text-anchor="middle" x="509.18" y="-2102.3" font-family="Times,serif" font-size="14.00">divide(·, ·)</text>
+</g>
+<!-- 101&#45;&gt;109 -->
+<g id="edge78" class="edge">
+<title>101&#45;&gt;109</title>
+<path fill="none" stroke="black" d="M393.27,-2303.85C415.3,-2266.44 467.77,-2177.32 493.88,-2132.98"/>
+<polygon fill="black" stroke="black" points="496.97,-2134.63 499.03,-2124.23 490.94,-2131.07 496.97,-2134.63"/>
+</g>
+<!-- 104&#45;&gt;105 -->
+<g id="edge75" class="edge">
+<title>104&#45;&gt;105</title>
+<path fill="none" stroke="black" d="M756.79,-2375.88C734.78,-2366.31 707.17,-2354.3 683.82,-2344.15"/>
+<polygon fill="black" stroke="black" points="685,-2340.85 674.43,-2340.07 682.21,-2347.27 685,-2340.85"/>
+</g>
+<!-- 106 -->
+<g id="node82" class="node">
+<title>106</title>
+<polygon fill="none" stroke="black" points="643.68,-2268 582.68,-2268 582.68,-2232 643.68,-2232 643.68,-2268"/>
+<text text-anchor="middle" x="613.18" y="-2246.3" font-family="Times,serif" font-size="14.00">sqrt(·)</text>
+</g>
+<!-- 105&#45;&gt;106 -->
+<g id="edge76" class="edge">
+<title>105&#45;&gt;106</title>
+<path fill="none" stroke="black" d="M629.74,-2303.7C627.29,-2295.9 624.34,-2286.51 621.61,-2277.83"/>
+<polygon fill="black" stroke="black" points="624.89,-2276.59 618.55,-2268.1 618.21,-2278.69 624.89,-2276.59"/>
+</g>
+<!-- 108 -->
+<g id="node83" class="node">
+<title>108</title>
+<polygon fill="none" stroke="black" points="645.68,-2196 538.68,-2196 538.68,-2160 645.68,-2160 645.68,-2196"/>
+<text text-anchor="middle" x="592.18" y="-2174.3" font-family="Times,serif" font-size="14.00">add(·, 1e&#45;12)</text>
+</g>
+<!-- 106&#45;&gt;108 -->
+<g id="edge77" class="edge">
+<title>106&#45;&gt;108</title>
+<path fill="none" stroke="black" d="M607.99,-2231.7C605.65,-2223.9 602.83,-2214.51 600.23,-2205.83"/>
+<polygon fill="black" stroke="black" points="603.53,-2204.68 597.31,-2196.1 596.83,-2206.69 603.53,-2204.68"/>
+</g>
+<!-- 108&#45;&gt;109 -->
+<g id="edge79" class="edge">
+<title>108&#45;&gt;109</title>
+<path fill="none" stroke="black" d="M571.66,-2159.7C561.31,-2150.97 548.59,-2140.24 537.34,-2130.75"/>
+<polygon fill="black" stroke="black" points="539.36,-2127.88 529.46,-2124.1 534.85,-2133.23 539.36,-2127.88"/>
+</g>
+<!-- 109&#45;&gt;110 -->
+<g id="edge80" class="edge">
+<title>109&#45;&gt;110</title>
+<path fill="none" stroke="black" d="M553.96,-2090.5C556.73,-2089.64 559.49,-2088.8 562.18,-2088 605.15,-2075.15 653.89,-2061.64 691.37,-2051.48"/>
+<polygon fill="black" stroke="black" points="692.39,-2054.83 701.13,-2048.84 690.57,-2048.07 692.39,-2054.83"/>
+</g>
+<!-- 110&#45;&gt;111 -->
+<g id="edge82" class="edge">
+<title>110&#45;&gt;111</title>
+<path fill="none" stroke="black" d="M805.38,-2018.28C808.35,-2017.5 811.3,-2016.73 814.18,-2016 878,-1999.76 952.33,-1983.03 998.82,-1972.81"/>
+<polygon fill="black" stroke="black" points="999.57,-1976.23 1008.59,-1970.67 998.08,-1969.4 999.57,-1976.23"/>
+</g>
+<!-- 112 -->
+<g id="node87" class="node">
+<title>112</title>
+<polygon fill="none" stroke="black" points="1415.18,-1908 967.18,-1908 967.18,-1872 1415.18,-1872 1415.18,-1908"/>
+<text text-anchor="middle" x="1191.18" y="-1886.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 &#160;14 768]| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 111&#45;&gt;112 -->
+<g id="edge84" class="edge">
+<title>111&#45;&gt;112</title>
+<path fill="none" stroke="black" d="M1079.76,-1944.05C1099.75,-1934.54 1124.88,-1922.57 1146.22,-1912.41"/>
+<polygon fill="black" stroke="black" points="1147.89,-1915.49 1155.41,-1908.03 1144.88,-1909.17 1147.89,-1915.49"/>
+</g>
+<!-- 142 -->
+<g id="node108" class="node">
+<title>142</title>
+<polygon fill="none" stroke="black" points="1084.68,-900 1013.68,-900 1013.68,-864 1084.68,-864 1084.68,-900"/>
+<text text-anchor="middle" x="1049.18" y="-878.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 111&#45;&gt;142 -->
+<g id="edge112" class="edge">
+<title>111&#45;&gt;142</title>
+<path fill="none" stroke="black" d="M1008.29,-1947.01C990.62,-1938.34 970.51,-1925.48 958.18,-1908 934.86,-1874.95 939.18,-1859.45 939.18,-1819 939.18,-1819 939.18,-1819 939.18,-1025 939.18,-974.76 982.74,-931.52 1015.04,-906.36"/>
+<polygon fill="black" stroke="black" points="1017.51,-908.88 1023.39,-900.06 1013.3,-903.29 1017.51,-908.88"/>
+</g>
+<!-- 117 -->
+<g id="node91" class="node">
+<title>117</title>
+<polygon fill="none" stroke="black" points="1454.68,-1836 1285.68,-1836 1285.68,-1800 1454.68,-1800 1454.68,-1836"/>
+<text text-anchor="middle" x="1370.18" y="-1814.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 112&#45;&gt;117 -->
+<g id="edge88" class="edge">
+<title>112&#45;&gt;117</title>
+<path fill="none" stroke="black" d="M1234.96,-1871.88C1259.76,-1862.18 1290.96,-1849.98 1317.13,-1839.74"/>
+<polygon fill="black" stroke="black" points="1318.49,-1842.97 1326.53,-1836.07 1315.94,-1836.45 1318.49,-1842.97"/>
+</g>
+<!-- 115 -->
+<g id="node89" class="node">
+<title>115</title>
+<polygon fill="none" stroke="black" points="1737.68,-1980 1340.68,-1980 1340.68,-1944 1737.68,-1944 1737.68,-1980"/>
+<text text-anchor="middle" x="1539.18" y="-1958.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;&#45;1 &#160;768 3072]| newshape=..., reverse=0)</text>
+</g>
+<!-- 113&#45;&gt;115 -->
+<g id="edge86" class="edge">
+<title>113&#45;&gt;115</title>
+<path fill="none" stroke="black" d="M1539.18,-2015.7C1539.18,-2007.98 1539.18,-1998.71 1539.18,-1990.11"/>
+<polygon fill="black" stroke="black" points="1542.68,-1990.1 1539.18,-1980.1 1535.68,-1990.1 1542.68,-1990.1"/>
+</g>
+<!-- 116 -->
+<g id="node90" class="node">
+<title>116</title>
+<polygon fill="none" stroke="black" points="1644.68,-1908 1433.68,-1908 1433.68,-1872 1644.68,-1872 1644.68,-1908"/>
+<text text-anchor="middle" x="1539.18" y="-1886.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 115&#45;&gt;116 -->
+<g id="edge87" class="edge">
+<title>115&#45;&gt;116</title>
+<path fill="none" stroke="black" d="M1539.18,-1943.7C1539.18,-1935.98 1539.18,-1926.71 1539.18,-1918.11"/>
+<polygon fill="black" stroke="black" points="1542.68,-1918.1 1539.18,-1908.1 1535.68,-1918.1 1542.68,-1918.1"/>
+</g>
+<!-- 116&#45;&gt;117 -->
+<g id="edge89" class="edge">
+<title>116&#45;&gt;117</title>
+<path fill="none" stroke="black" d="M1497.83,-1871.88C1474.63,-1862.26 1445.49,-1850.19 1420.91,-1840.01"/>
+<polygon fill="black" stroke="black" points="1421.96,-1836.66 1411.38,-1836.07 1419.28,-1843.13 1421.96,-1836.66"/>
+</g>
+<!-- 119 -->
+<g id="node92" class="node">
+<title>119</title>
+<polygon fill="none" stroke="black" points="1605.18,-1764 1135.18,-1764 1135.18,-1728 1605.18,-1728 1605.18,-1764"/>
+<text text-anchor="middle" x="1370.18" y="-1742.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;&#160;1 &#160;&#160;14 3072]| newshape=[1, 14, 3072], reverse=0)</text>
+</g>
+<!-- 117&#45;&gt;119 -->
+<g id="edge90" class="edge">
+<title>117&#45;&gt;119</title>
+<path fill="none" stroke="black" d="M1370.18,-1799.7C1370.18,-1791.98 1370.18,-1782.71 1370.18,-1774.11"/>
+<polygon fill="black" stroke="black" points="1373.68,-1774.1 1370.18,-1764.1 1366.68,-1774.1 1373.68,-1774.1"/>
+</g>
+<!-- 119&#45;&gt;120 -->
+<g id="edge91" class="edge">
+<title>119&#45;&gt;120</title>
+<path fill="none" stroke="black" d="M1458.27,-1727.97C1530.33,-1714.03 1629.35,-1694.88 1686.47,-1683.84"/>
+<polygon fill="black" stroke="black" points="1687.14,-1687.27 1696.3,-1681.94 1685.81,-1680.4 1687.14,-1687.27"/>
+</g>
+<!-- 124 -->
+<g id="node94" class="node">
+<title>124</title>
+<polygon fill="none" stroke="black" points="1893.68,-1620 1708.68,-1620 1708.68,-1584 1893.68,-1584 1893.68,-1620"/>
+<text text-anchor="middle" x="1801.18" y="-1598.3" font-family="Times,serif" font-size="14.00">multiply(·, 0.70710677)</text>
+</g>
+<!-- 120&#45;&gt;124 -->
+<g id="edge93" class="edge">
+<title>120&#45;&gt;124</title>
+<path fill="none" stroke="black" d="M1749.23,-1655.7C1757.67,-1647.14 1768,-1636.66 1777.22,-1627.3"/>
+<polygon fill="black" stroke="black" points="1779.79,-1629.68 1784.32,-1620.1 1774.8,-1624.77 1779.79,-1629.68"/>
+</g>
+<!-- 129 -->
+<g id="node98" class="node">
+<title>129</title>
+<polygon fill="none" stroke="black" points="1773.18,-1332 1669.18,-1332 1669.18,-1296 1773.18,-1296 1773.18,-1332"/>
+<text text-anchor="middle" x="1721.18" y="-1310.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 120&#45;&gt;129 -->
+<g id="edge97" class="edge">
+<title>120&#45;&gt;129</title>
+<path fill="none" stroke="black" d="M1719.19,-1655.7C1712.29,-1645.73 1704.22,-1632.69 1699.18,-1620 1684.27,-1582.5 1681.18,-1571.36 1681.18,-1531 1681.18,-1531 1681.18,-1531 1681.18,-1457 1681.18,-1415.48 1697.21,-1369.61 1708.98,-1341.54"/>
+<polygon fill="black" stroke="black" points="1712.33,-1342.62 1713.08,-1332.05 1705.9,-1339.84 1712.33,-1342.62"/>
+</g>
+<!-- 125 -->
+<g id="node95" class="node">
+<title>125</title>
+<polygon fill="none" stroke="black" points="1820.18,-1548 1766.18,-1548 1766.18,-1512 1820.18,-1512 1820.18,-1548"/>
+<text text-anchor="middle" x="1793.18" y="-1526.3" font-family="Times,serif" font-size="14.00">erf(·)</text>
+</g>
+<!-- 124&#45;&gt;125 -->
+<g id="edge94" class="edge">
+<title>124&#45;&gt;125</title>
+<path fill="none" stroke="black" d="M1799.2,-1583.7C1798.32,-1575.98 1797.26,-1566.71 1796.28,-1558.11"/>
+<polygon fill="black" stroke="black" points="1799.74,-1557.64 1795.13,-1548.1 1792.79,-1558.44 1799.74,-1557.64"/>
+</g>
+<!-- 127 -->
+<g id="node96" class="node">
+<title>127</title>
+<polygon fill="none" stroke="black" points="1846.18,-1476 1724.18,-1476 1724.18,-1440 1846.18,-1440 1846.18,-1476"/>
+<text text-anchor="middle" x="1785.18" y="-1454.3" font-family="Times,serif" font-size="14.00">multiply(·, 0.5)</text>
+</g>
+<!-- 125&#45;&gt;127 -->
+<g id="edge95" class="edge">
+<title>125&#45;&gt;127</title>
+<path fill="none" stroke="black" d="M1791.2,-1511.7C1790.32,-1503.98 1789.26,-1494.71 1788.28,-1486.11"/>
+<polygon fill="black" stroke="black" points="1791.74,-1485.64 1787.13,-1476.1 1784.79,-1486.44 1791.74,-1485.64"/>
+</g>
+<!-- 128 -->
+<g id="node97" class="node">
+<title>128</title>
+<polygon fill="none" stroke="black" points="1806.68,-1404 1717.68,-1404 1717.68,-1368 1806.68,-1368 1806.68,-1404"/>
+<text text-anchor="middle" x="1762.18" y="-1382.3" font-family="Times,serif" font-size="14.00">add(0.5, ·)</text>
+</g>
+<!-- 127&#45;&gt;128 -->
+<g id="edge96" class="edge">
+<title>127&#45;&gt;128</title>
+<path fill="none" stroke="black" d="M1779.49,-1439.7C1776.93,-1431.9 1773.84,-1422.51 1770.99,-1413.83"/>
+<polygon fill="black" stroke="black" points="1774.24,-1412.51 1767.8,-1404.1 1767.59,-1414.7 1774.24,-1412.51"/>
+</g>
+<!-- 128&#45;&gt;129 -->
+<g id="edge98" class="edge">
+<title>128&#45;&gt;129</title>
+<path fill="none" stroke="black" d="M1752.04,-1367.7C1747.32,-1359.64 1741.61,-1349.89 1736.39,-1340.98"/>
+<polygon fill="black" stroke="black" points="1739.27,-1338.96 1731.19,-1332.1 1733.23,-1342.5 1739.27,-1338.96"/>
+</g>
+<!-- 131 -->
+<g id="node99" class="node">
+<title>131</title>
+<polygon fill="none" stroke="black" points="1958.68,-1260 1483.68,-1260 1483.68,-1224 1958.68,-1224 1958.68,-1260"/>
+<text text-anchor="middle" x="1721.18" y="-1238.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;&#45;1 &#160;&#160;14 3072]| newshape=[&#45;1, 14, 3072], reverse=0)</text>
+</g>
+<!-- 129&#45;&gt;131 -->
+<g id="edge99" class="edge">
+<title>129&#45;&gt;131</title>
+<path fill="none" stroke="black" d="M1721.18,-1295.7C1721.18,-1287.98 1721.18,-1278.71 1721.18,-1270.11"/>
+<polygon fill="black" stroke="black" points="1724.68,-1270.1 1721.18,-1260.1 1717.68,-1270.1 1724.68,-1270.1"/>
+</g>
+<!-- 136 -->
+<g id="node103" class="node">
+<title>136</title>
+<polygon fill="none" stroke="black" points="1675.68,-1188 1506.68,-1188 1506.68,-1152 1675.68,-1152 1675.68,-1188"/>
+<text text-anchor="middle" x="1591.18" y="-1166.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 131&#45;&gt;136 -->
+<g id="edge103" class="edge">
+<title>131&#45;&gt;136</title>
+<path fill="none" stroke="black" d="M1689.37,-1223.88C1672.06,-1214.55 1650.46,-1202.92 1631.92,-1192.94"/>
+<polygon fill="black" stroke="black" points="1633.56,-1189.85 1623.1,-1188.19 1630.25,-1196.01 1633.56,-1189.85"/>
+</g>
+<!-- 134 -->
+<g id="node101" class="node">
+<title>134</title>
+<polygon fill="none" stroke="black" points="1558.68,-1332 1161.68,-1332 1161.68,-1296 1558.68,-1296 1558.68,-1332"/>
+<text text-anchor="middle" x="1360.18" y="-1310.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;&#45;1 3072 &#160;768]| newshape=..., reverse=0)</text>
+</g>
+<!-- 132&#45;&gt;134 -->
+<g id="edge101" class="edge">
+<title>132&#45;&gt;134</title>
+<path fill="none" stroke="black" d="M1360.18,-1367.7C1360.18,-1359.98 1360.18,-1350.71 1360.18,-1342.11"/>
+<polygon fill="black" stroke="black" points="1363.68,-1342.1 1360.18,-1332.1 1356.68,-1342.1 1363.68,-1342.1"/>
+</g>
+<!-- 135 -->
+<g id="node102" class="node">
+<title>135</title>
+<polygon fill="none" stroke="black" points="1465.68,-1260 1254.68,-1260 1254.68,-1224 1465.68,-1224 1465.68,-1260"/>
+<text text-anchor="middle" x="1360.18" y="-1238.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 134&#45;&gt;135 -->
+<g id="edge102" class="edge">
+<title>134&#45;&gt;135</title>
+<path fill="none" stroke="black" d="M1360.18,-1295.7C1360.18,-1287.98 1360.18,-1278.71 1360.18,-1270.11"/>
+<polygon fill="black" stroke="black" points="1363.68,-1270.1 1360.18,-1260.1 1356.68,-1270.1 1363.68,-1270.1"/>
+</g>
+<!-- 135&#45;&gt;136 -->
+<g id="edge104" class="edge">
+<title>135&#45;&gt;136</title>
+<path fill="none" stroke="black" d="M1416.39,-1223.97C1449.09,-1214.06 1490.48,-1201.51 1524.78,-1191.12"/>
+<polygon fill="black" stroke="black" points="1526.16,-1194.36 1534.71,-1188.11 1524.13,-1187.66 1526.16,-1194.36"/>
+</g>
+<!-- 137 -->
+<g id="node104" class="node">
+<title>137</title>
+<polygon fill="none" stroke="black" points="1812.68,-1116 1369.68,-1116 1369.68,-1080 1812.68,-1080 1812.68,-1116"/>
+<text text-anchor="middle" x="1591.18" y="-1094.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;1 &#160;14 768]| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 136&#45;&gt;137 -->
+<g id="edge105" class="edge">
+<title>136&#45;&gt;137</title>
+<path fill="none" stroke="black" d="M1591.18,-1151.7C1591.18,-1143.98 1591.18,-1134.71 1591.18,-1126.11"/>
+<polygon fill="black" stroke="black" points="1594.68,-1126.1 1591.18,-1116.1 1587.68,-1126.1 1594.68,-1126.1"/>
+</g>
+<!-- 137&#45;&gt;138 -->
+<g id="edge106" class="edge">
+<title>137&#45;&gt;138</title>
+<path fill="none" stroke="black" d="M1591.18,-1079.7C1591.18,-1071.98 1591.18,-1062.71 1591.18,-1054.11"/>
+<polygon fill="black" stroke="black" points="1594.68,-1054.1 1591.18,-1044.1 1587.68,-1054.1 1594.68,-1054.1"/>
+</g>
+<!-- 141 -->
+<g id="node107" class="node">
+<title>141</title>
+<polygon fill="none" stroke="black" points="1211.18,-972 1107.18,-972 1107.18,-936 1211.18,-936 1211.18,-972"/>
+<text text-anchor="middle" x="1159.18" y="-950.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 138&#45;&gt;141 -->
+<g id="edge109" class="edge">
+<title>138&#45;&gt;141</title>
+<path fill="none" stroke="black" d="M1555.46,-1019.21C1482,-1007.31 1313.2,-979.96 1221.4,-965.08"/>
+<polygon fill="black" stroke="black" points="1221.84,-961.61 1211.4,-963.46 1220.72,-968.52 1221.84,-961.61"/>
+</g>
+<!-- 140&#45;&gt;141 -->
+<g id="edge110" class="edge">
+<title>140&#45;&gt;141</title>
+<path fill="none" stroke="black" d="M1159.18,-1007.7C1159.18,-999.98 1159.18,-990.71 1159.18,-982.11"/>
+<polygon fill="black" stroke="black" points="1162.68,-982.1 1159.18,-972.1 1155.68,-982.1 1162.68,-982.1"/>
+</g>
+<!-- 141&#45;&gt;142 -->
+<g id="edge111" class="edge">
+<title>141&#45;&gt;142</title>
+<path fill="none" stroke="black" d="M1132.27,-935.88C1118.02,-926.81 1100.33,-915.55 1084.94,-905.76"/>
+<polygon fill="black" stroke="black" points="1086.5,-902.61 1076.19,-900.19 1082.75,-908.51 1086.5,-902.61"/>
+</g>
+<!-- 143 -->
+<g id="node109" class="node">
+<title>143</title>
+<polygon fill="none" stroke="black" points="1117.68,-828 790.68,-828 790.68,-792 1117.68,-792 1117.68,-828"/>
+<text text-anchor="middle" x="954.18" y="-806.3" font-family="Times,serif" font-size="14.00">mean(·| axis=[&#45;1], keepdims=1, exclude=0)</text>
+</g>
+<!-- 142&#45;&gt;143 -->
+<g id="edge113" class="edge">
+<title>142&#45;&gt;143</title>
+<path fill="none" stroke="black" d="M1025.69,-863.7C1013.61,-854.8 998.72,-843.82 985.66,-834.2"/>
+<polygon fill="black" stroke="black" points="987.52,-831.22 977.39,-828.1 983.36,-836.85 987.52,-831.22"/>
+</g>
+<!-- 144 -->
+<g id="node110" class="node">
+<title>144</title>
+<polygon fill="none" stroke="black" points="975.68,-756 870.68,-756 870.68,-720 975.68,-720 975.68,-756"/>
+<text text-anchor="middle" x="923.18" y="-734.3" font-family="Times,serif" font-size="14.00">subtract(·, ·)</text>
+</g>
+<!-- 142&#45;&gt;144 -->
+<g id="edge114" class="edge">
+<title>142&#45;&gt;144</title>
+<path fill="none" stroke="black" d="M1013.5,-879.36C947.45,-875.5 810.19,-863.32 781.18,-828 771.02,-815.64 772.05,-805.14 781.18,-792 799.1,-766.19 831.2,-752.92 860.35,-746.12"/>
+<polygon fill="black" stroke="black" points="861.45,-749.46 870.51,-743.97 860,-742.61 861.45,-749.46"/>
+</g>
+<!-- 145 -->
+<g id="node111" class="node">
+<title>145</title>
+<polygon fill="none" stroke="black" points="1499.68,-828 1172.68,-828 1172.68,-792 1499.68,-792 1499.68,-828"/>
+<text text-anchor="middle" x="1336.18" y="-806.3" font-family="Times,serif" font-size="14.00">mean(·| axis=[&#45;1], keepdims=1, exclude=0)</text>
+</g>
+<!-- 142&#45;&gt;145 -->
+<g id="edge116" class="edge">
+<title>142&#45;&gt;145</title>
+<path fill="none" stroke="black" d="M1084.94,-872.28C1127.61,-861.87 1200.23,-844.16 1256.45,-830.45"/>
+<polygon fill="black" stroke="black" points="1257.5,-833.79 1266.39,-828.02 1255.84,-826.99 1257.5,-833.79"/>
+</g>
+<!-- 146 -->
+<g id="node112" class="node">
+<title>146</title>
+<polygon fill="none" stroke="black" points="1356.18,-756 994.18,-756 994.18,-720 1356.18,-720 1356.18,-756"/>
+<text text-anchor="middle" x="1175.18" y="-734.3" font-family="Times,serif" font-size="14.00">variance(·, ·| axis=[&#45;1], keepdims=1, exclude=0)</text>
+</g>
+<!-- 142&#45;&gt;146 -->
+<g id="edge117" class="edge">
+<title>142&#45;&gt;146</title>
+<path fill="none" stroke="black" d="M1080.98,-863.78C1095.96,-854.57 1113.33,-842.2 1126.18,-828 1142.91,-809.51 1156.1,-784.31 1164.56,-765.44"/>
+<polygon fill="black" stroke="black" points="1167.79,-766.78 1168.55,-756.22 1161.37,-764.01 1167.79,-766.78"/>
+</g>
+<!-- 143&#45;&gt;144 -->
+<g id="edge115" class="edge">
+<title>143&#45;&gt;144</title>
+<path fill="none" stroke="black" d="M946.51,-791.7C942.98,-783.73 938.72,-774.1 934.81,-765.26"/>
+<polygon fill="black" stroke="black" points="938,-763.83 930.75,-756.1 931.6,-766.67 938,-763.83"/>
+</g>
+<!-- 150 -->
+<g id="node115" class="node">
+<title>150</title>
+<polygon fill="none" stroke="black" points="1114.68,-540 1025.68,-540 1025.68,-504 1114.68,-504 1114.68,-540"/>
+<text text-anchor="middle" x="1070.18" y="-518.3" font-family="Times,serif" font-size="14.00">divide(·, ·)</text>
+</g>
+<!-- 144&#45;&gt;150 -->
+<g id="edge121" class="edge">
+<title>144&#45;&gt;150</title>
+<path fill="none" stroke="black" d="M934.96,-719.85C960.71,-682.36 1022.13,-592.95 1052.52,-548.7"/>
+<polygon fill="black" stroke="black" points="1055.56,-550.46 1058.34,-540.23 1049.79,-546.49 1055.56,-550.46"/>
+</g>
+<!-- 145&#45;&gt;146 -->
+<g id="edge118" class="edge">
+<title>145&#45;&gt;146</title>
+<path fill="none" stroke="black" d="M1296.79,-791.88C1274.78,-782.31 1247.17,-770.3 1223.82,-760.15"/>
+<polygon fill="black" stroke="black" points="1225,-756.85 1214.43,-756.07 1222.21,-763.27 1225,-756.85"/>
+</g>
+<!-- 147 -->
+<g id="node113" class="node">
+<title>147</title>
+<polygon fill="none" stroke="black" points="1183.68,-684 1122.68,-684 1122.68,-648 1183.68,-648 1183.68,-684"/>
+<text text-anchor="middle" x="1153.18" y="-662.3" font-family="Times,serif" font-size="14.00">sqrt(·)</text>
+</g>
+<!-- 146&#45;&gt;147 -->
+<g id="edge119" class="edge">
+<title>146&#45;&gt;147</title>
+<path fill="none" stroke="black" d="M1169.74,-719.7C1167.29,-711.9 1164.34,-702.51 1161.61,-693.83"/>
+<polygon fill="black" stroke="black" points="1164.89,-692.59 1158.55,-684.1 1158.21,-694.69 1164.89,-692.59"/>
+</g>
+<!-- 149 -->
+<g id="node114" class="node">
+<title>149</title>
+<polygon fill="none" stroke="black" points="1185.68,-612 1078.68,-612 1078.68,-576 1185.68,-576 1185.68,-612"/>
+<text text-anchor="middle" x="1132.18" y="-590.3" font-family="Times,serif" font-size="14.00">add(·, 1e&#45;12)</text>
+</g>
+<!-- 147&#45;&gt;149 -->
+<g id="edge120" class="edge">
+<title>147&#45;&gt;149</title>
+<path fill="none" stroke="black" d="M1147.99,-647.7C1145.65,-639.9 1142.83,-630.51 1140.23,-621.83"/>
+<polygon fill="black" stroke="black" points="1143.53,-620.68 1137.31,-612.1 1136.83,-622.69 1143.53,-620.68"/>
+</g>
+<!-- 149&#45;&gt;150 -->
+<g id="edge122" class="edge">
+<title>149&#45;&gt;150</title>
+<path fill="none" stroke="black" d="M1116.85,-575.7C1109.42,-567.3 1100.35,-557.07 1092.2,-547.86"/>
+<polygon fill="black" stroke="black" points="1094.58,-545.27 1085.33,-540.1 1089.34,-549.91 1094.58,-545.27"/>
+</g>
+<!-- 150&#45;&gt;151 -->
+<g id="edge123" class="edge">
+<title>150&#45;&gt;151</title>
+<path fill="none" stroke="black" d="M1032.99,-503.88C1012.31,-494.35 986.37,-482.41 964.39,-472.28"/>
+<polygon fill="black" stroke="black" points="965.79,-469.07 955.24,-468.07 962.86,-475.43 965.79,-469.07"/>
+</g>
+<!-- 151&#45;&gt;152 -->
+<g id="edge125" class="edge">
+<title>151&#45;&gt;152</title>
+<path fill="none" stroke="black" d="M881.48,-431.88C860.76,-422.21 834.71,-410.05 812.82,-399.83"/>
+<polygon fill="black" stroke="black" points="814.26,-396.64 803.72,-395.59 811.3,-402.99 814.26,-396.64"/>
+</g>
+<!-- 153 -->
+<g id="node118" class="node">
+<title>153</title>
+<polygon fill="none" stroke="black" points="811.18,-324 725.18,-324 725.18,-288 811.18,-288 811.18,-324"/>
+<text text-anchor="middle" x="768.18" y="-302.3" font-family="Times,serif" font-size="14.00">Tuple[...])</text>
+</g>
+<!-- 152&#45;&gt;153 -->
+<g id="edge127" class="edge">
+<title>152&#45;&gt;153</title>
+<path fill="none" stroke="black" d="M768.18,-359.7C768.18,-351.98 768.18,-342.71 768.18,-334.11"/>
+<polygon fill="black" stroke="black" points="771.68,-334.1 768.18,-324.1 764.68,-334.1 771.68,-334.1"/>
+</g>
+<!-- 154 -->
+<g id="node119" class="node">
+<title>154</title>
+<polygon fill="none" stroke="black" points="852.18,-252 684.18,-252 684.18,-216 852.18,-216 852.18,-252"/>
+<text text-anchor="middle" x="768.18" y="-230.3" font-family="Times,serif" font-size="14.00">TupleGetItem(idx=0)</text>
+</g>
+<!-- 153&#45;&gt;154 -->
+<g id="edge128" class="edge">
+<title>153&#45;&gt;154</title>
+<path fill="none" stroke="black" d="M768.18,-287.7C768.18,-279.98 768.18,-270.71 768.18,-262.11"/>
+<polygon fill="black" stroke="black" points="771.68,-262.1 768.18,-252.1 764.68,-262.1 771.68,-262.1"/>
+</g>
+<!-- 154&#45;&gt;155 -->
+<g id="edge129" class="edge">
+<title>154&#45;&gt;155</title>
+<path fill="none" stroke="black" d="M733.19,-215.88C713.98,-206.47 689.95,-194.71 669.44,-184.67"/>
+<polygon fill="black" stroke="black" points="670.81,-181.44 660.29,-180.19 667.74,-187.73 670.81,-181.44"/>
+</g>
+<!-- 156 -->
+<g id="node121" class="node">
+<title>156</title>
+<polygon fill="none" stroke="black" points="791.18,-108 459.18,-108 459.18,-72 791.18,-72 791.18,-108"/>
+<text text-anchor="middle" x="625.18" y="-86.3" font-family="Times,serif" font-size="14.00">sum(·| axis=None, keepdims=0, exclude=0)</text>
+</g>
+<!-- 155&#45;&gt;156 -->
+<g id="edge131" class="edge">
+<title>155&#45;&gt;156</title>
+<path fill="none" stroke="black" d="M625.18,-143.7C625.18,-135.98 625.18,-126.71 625.18,-118.11"/>
+<polygon fill="black" stroke="black" points="628.68,-118.1 625.18,-108.1 621.68,-118.1 628.68,-118.1"/>
+</g>
+<!-- 157 -->
+<g id="node122" class="node">
+<title>157</title>
+<polygon fill="none" stroke="black" points="665.18,-36 585.18,-36 585.18,0 665.18,0 665.18,-36"/>
+<text text-anchor="middle" x="625.18" y="-14.3" font-family="Times,serif" font-size="14.00">Function</text>
+</g>
+<!-- 156&#45;&gt;157 -->
+<g id="edge132" class="edge">
+<title>156&#45;&gt;157</title>
+<path fill="none" stroke="black" d="M625.18,-71.7C625.18,-63.98 625.18,-54.71 625.18,-46.11"/>
+<polygon fill="black" stroke="black" points="628.68,-46.1 625.18,-36.1 621.68,-46.1 628.68,-46.1"/>
+</g>
+</g>
+</svg>
diff --git a/images/bert-pytorch/pytorch-tvm-training_31_0.svg b/images/bert-pytorch/pytorch-tvm-training_31_0.svg
new file mode 100644
index 0000000..c3a3472
--- /dev/null
+++ b/images/bert-pytorch/pytorch-tvm-training_31_0.svg
@@ -0,0 +1,4015 @@
+<?xml version="1.0" encoding="UTF-8" standalone="no"?>
+<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
+ "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
+<!-- Generated by graphviz version 2.43.0 (0)
+ -->
+<!-- Title: %3 Pages: 1 -->
+<svg width="6041pt" height="8756pt"
+ viewBox="0.00 0.00 6041.28 8756.00" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
+<g id="graph0" class="graph" transform="scale(1 1) rotate(0) translate(4 8752)">
+<title>%3</title>
+<polygon fill="white" stroke="transparent" points="-4,4 -4,-8752 6037.28,-8752 6037.28,4 -4,4"/>
+<!-- 0 -->
+<g id="node1" class="node">
+<title>0</title>
+<ellipse fill="none" stroke="black" cx="3735.78" cy="-8586" rx="170.87" ry="18"/>
+<text text-anchor="middle" x="3735.78" y="-8582.3" font-family="Times,serif" font-size="14.00">input: Tensor[(1, 14, 768), float32]</text>
+</g>
+<!-- 32 -->
+<g id="node23" class="node">
+<title>32</title>
+<polygon fill="none" stroke="black" points="4397.78,-8532 3949.78,-8532 3949.78,-8496 4397.78,-8496 4397.78,-8532"/>
+<text text-anchor="middle" x="4173.78" y="-8510.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 &#160;14 768]| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 0&#45;&gt;32 -->
+<g id="edge1" class="edge">
+<title>0&#45;&gt;32</title>
+<path fill="none" stroke="black" d="M3825.47,-8570.67C3892.4,-8559.97 3984.44,-8545.26 4057.09,-8533.65"/>
+<polygon fill="black" stroke="black" points="4057.89,-8537.07 4067.21,-8532.03 4056.78,-8530.15 4057.89,-8537.07"/>
+</g>
+<!-- 95 -->
+<g id="node75" class="node">
+<title>95</title>
+<polygon fill="none" stroke="black" points="3562.28,-6732 3491.28,-6732 3491.28,-6696 3562.28,-6696 3562.28,-6732"/>
+<text text-anchor="middle" x="3526.78" y="-6710.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 0&#45;&gt;95 -->
+<g id="edge67" class="edge">
+<title>0&#45;&gt;95</title>
+<path fill="none" stroke="black" d="M3651.14,-8570.3C3541.09,-8548.73 3362.78,-8504.32 3362.78,-8443 3362.78,-8443 3362.78,-8443 3362.78,-6857 3362.78,-6793.8 3433.25,-6752.44 3481.8,-6731.44"/>
+<polygon fill="black" stroke="black" points="3483.32,-6734.6 3491.19,-6727.51 3480.62,-6728.14 3483.32,-6734.6"/>
+</g>
+<!-- 1 -->
+<g id="node2" class="node">
+<title>1</title>
+<ellipse fill="none" stroke="black" cx="5402.78" cy="-7794" rx="217.96" ry="18"/>
+<text text-anchor="middle" x="5402.78" y="-7790.3" font-family="Times,serif" font-size="14.00">attention_mask: Tensor[(1, 1, 1, 14), float32]</text>
+</g>
+<!-- 62 -->
+<g id="node46" class="node">
+<title>62</title>
+<polygon fill="none" stroke="black" points="5438.28,-7740 5367.28,-7740 5367.28,-7704 5438.28,-7704 5438.28,-7740"/>
+<text text-anchor="middle" x="5402.78" y="-7718.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 1&#45;&gt;62 -->
+<g id="edge30" class="edge">
+<title>1&#45;&gt;62</title>
+<path fill="none" stroke="black" d="M5402.78,-7775.7C5402.78,-7767.98 5402.78,-7758.71 5402.78,-7750.11"/>
+<polygon fill="black" stroke="black" points="5406.28,-7750.1 5402.78,-7740.1 5399.28,-7750.1 5406.28,-7750.1"/>
+</g>
+<!-- 2 -->
+<g id="node3" class="node">
+<title>2</title>
+<ellipse fill="none" stroke="black" cx="5413.78" cy="-8658" rx="265.65" ry="18"/>
+<text text-anchor="middle" x="5413.78" y="-8654.3" font-family="Times,serif" font-size="14.00">attention.self.query.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 33 -->
+<g id="node24" class="node">
+<title>33</title>
+<polygon fill="none" stroke="black" points="5510.28,-8604 5317.28,-8604 5317.28,-8568 5510.28,-8568 5510.28,-8604"/>
+<text text-anchor="middle" x="5413.78" y="-8582.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 2&#45;&gt;33 -->
+<g id="edge2" class="edge">
+<title>2&#45;&gt;33</title>
+<path fill="none" stroke="black" d="M5413.78,-8639.7C5413.78,-8631.98 5413.78,-8622.71 5413.78,-8614.11"/>
+<polygon fill="black" stroke="black" points="5417.28,-8614.1 5413.78,-8604.1 5410.28,-8614.1 5417.28,-8614.1"/>
+</g>
+<!-- 3 -->
+<g id="node4" class="node">
+<title>3</title>
+<ellipse fill="none" stroke="black" cx="4434.78" cy="-8298" rx="232.86" ry="18"/>
+<text text-anchor="middle" x="4434.78" y="-8294.3" font-family="Times,serif" font-size="14.00">attention.self.query.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 40 -->
+<g id="node29" class="node">
+<title>40</title>
+<polygon fill="none" stroke="black" points="4942.28,-8244 4871.28,-8244 4871.28,-8208 4942.28,-8208 4942.28,-8244"/>
+<text text-anchor="middle" x="4906.78" y="-8222.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 3&#45;&gt;40 -->
+<g id="edge9" class="edge">
+<title>3&#45;&gt;40</title>
+<path fill="none" stroke="black" d="M4537.2,-8281.81C4637.8,-8266.89 4786.69,-8244.81 4861.05,-8233.78"/>
+<polygon fill="black" stroke="black" points="4861.65,-8237.23 4871.03,-8232.3 4860.62,-8230.31 4861.65,-8237.23"/>
+</g>
+<!-- 4 -->
+<g id="node5" class="node">
+<title>4</title>
+<ellipse fill="none" stroke="black" cx="4970.78" cy="-8730" rx="254.55" ry="18"/>
+<text text-anchor="middle" x="4970.78" y="-8726.3" font-family="Times,serif" font-size="14.00">attention.self.key.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 46 -->
+<g id="node33" class="node">
+<title>46</title>
+<polygon fill="none" stroke="black" points="5067.28,-8676 4874.28,-8676 4874.28,-8640 5067.28,-8640 5067.28,-8676"/>
+<text text-anchor="middle" x="4970.78" y="-8654.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 4&#45;&gt;46 -->
+<g id="edge13" class="edge">
+<title>4&#45;&gt;46</title>
+<path fill="none" stroke="black" d="M4970.78,-8711.7C4970.78,-8703.98 4970.78,-8694.71 4970.78,-8686.11"/>
+<polygon fill="black" stroke="black" points="4974.28,-8686.1 4970.78,-8676.1 4967.28,-8686.1 4974.28,-8686.1"/>
+</g>
+<!-- 5 -->
+<g id="node6" class="node">
+<title>5</title>
+<ellipse fill="none" stroke="black" cx="5409.78" cy="-8370" rx="221.76" ry="18"/>
+<text text-anchor="middle" x="5409.78" y="-8366.3" font-family="Times,serif" font-size="14.00">attention.self.key.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 51 -->
+<g id="node38" class="node">
+<title>51</title>
+<polygon fill="none" stroke="black" points="5445.28,-8316 5374.28,-8316 5374.28,-8280 5445.28,-8280 5445.28,-8316"/>
+<text text-anchor="middle" x="5409.78" y="-8294.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 5&#45;&gt;51 -->
+<g id="edge20" class="edge">
+<title>5&#45;&gt;51</title>
+<path fill="none" stroke="black" d="M5409.78,-8351.7C5409.78,-8343.98 5409.78,-8334.71 5409.78,-8326.11"/>
+<polygon fill="black" stroke="black" points="5413.28,-8326.1 5409.78,-8316.1 5406.28,-8326.1 5413.28,-8326.1"/>
+</g>
+<!-- 6 -->
+<g id="node7" class="node">
+<title>6</title>
+<ellipse fill="none" stroke="black" cx="3880.78" cy="-8226" rx="265.35" ry="18"/>
+<text text-anchor="middle" x="3880.78" y="-8222.3" font-family="Times,serif" font-size="14.00">attention.self.value.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 69 -->
+<g id="node51" class="node">
+<title>69</title>
+<polygon fill="none" stroke="black" points="3985.28,-8172 3792.28,-8172 3792.28,-8136 3985.28,-8136 3985.28,-8172"/>
+<text text-anchor="middle" x="3888.78" y="-8150.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 6&#45;&gt;69 -->
+<g id="edge36" class="edge">
+<title>6&#45;&gt;69</title>
+<path fill="none" stroke="black" d="M3882.76,-8207.7C3883.64,-8199.98 3884.7,-8190.71 3885.68,-8182.11"/>
+<polygon fill="black" stroke="black" points="3889.17,-8182.44 3886.82,-8172.1 3882.21,-8181.64 3889.17,-8182.44"/>
+</g>
+<!-- 7 -->
+<g id="node8" class="node">
+<title>7</title>
+<ellipse fill="none" stroke="black" cx="4762.78" cy="-7866" rx="232.06" ry="18"/>
+<text text-anchor="middle" x="4762.78" y="-7862.3" font-family="Times,serif" font-size="14.00">attention.self.value.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 74 -->
+<g id="node56" class="node">
+<title>74</title>
+<polygon fill="none" stroke="black" points="4414.28,-7812 4343.28,-7812 4343.28,-7776 4414.28,-7776 4414.28,-7812"/>
+<text text-anchor="middle" x="4378.78" y="-7790.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 7&#45;&gt;74 -->
+<g id="edge43" class="edge">
+<title>7&#45;&gt;74</title>
+<path fill="none" stroke="black" d="M4676.12,-7849.2C4597.97,-7834.96 4486.08,-7814.56 4424.23,-7803.29"/>
+<polygon fill="black" stroke="black" points="4424.85,-7799.84 4414.38,-7801.49 4423.59,-7806.73 4424.85,-7799.84"/>
+</g>
+<!-- 8 -->
+<g id="node9" class="node">
+<title>8</title>
+<ellipse fill="none" stroke="black" cx="4346.78" cy="-7290" rx="282.15" ry="18"/>
+<text text-anchor="middle" x="4346.78" y="-7286.3" font-family="Times,serif" font-size="14.00">attention.output.dense.weight: Tensor[(768, 768), float32]</text>
+</g>
+<!-- 86 -->
+<g id="node67" class="node">
+<title>86</title>
+<polygon fill="none" stroke="black" points="4456.28,-7236 4263.28,-7236 4263.28,-7200 4456.28,-7200 4456.28,-7236"/>
+<text text-anchor="middle" x="4359.78" y="-7214.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 8&#45;&gt;86 -->
+<g id="edge55" class="edge">
+<title>8&#45;&gt;86</title>
+<path fill="none" stroke="black" d="M4349.99,-7271.7C4351.43,-7263.98 4353.15,-7254.71 4354.74,-7246.11"/>
+<polygon fill="black" stroke="black" points="4358.22,-7246.58 4356.6,-7236.1 4351.34,-7245.3 4358.22,-7246.58"/>
+</g>
+<!-- 9 -->
+<g id="node10" class="node">
+<title>9</title>
+<ellipse fill="none" stroke="black" cx="3639.78" cy="-6930" rx="248.86" ry="18"/>
+<text text-anchor="middle" x="3639.78" y="-6926.3" font-family="Times,serif" font-size="14.00">attention.output.dense.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 91 -->
+<g id="node72" class="node">
+<title>91</title>
+<polygon fill="none" stroke="black" points="3675.28,-6876 3604.28,-6876 3604.28,-6840 3675.28,-6840 3675.28,-6876"/>
+<text text-anchor="middle" x="3639.78" y="-6854.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 9&#45;&gt;91 -->
+<g id="edge62" class="edge">
+<title>9&#45;&gt;91</title>
+<path fill="none" stroke="black" d="M3639.78,-6911.7C3639.78,-6903.98 3639.78,-6894.71 3639.78,-6886.11"/>
+<polygon fill="black" stroke="black" points="3643.28,-6886.1 3639.78,-6876.1 3636.28,-6886.1 3643.28,-6886.1"/>
+</g>
+<!-- 10 -->
+<g id="node11" class="node">
+<title>10</title>
+<ellipse fill="none" stroke="black" cx="3344.78" cy="-6354" rx="286.75" ry="18"/>
+<text text-anchor="middle" x="3344.78" y="-6350.3" font-family="Times,serif" font-size="14.00">attention.output.LayerNorm.weight: Tensor[(768,), float32]</text>
+</g>
+<!-- 106 -->
+<g id="node82" class="node">
+<title>106</title>
+<polygon fill="none" stroke="black" points="3285.78,-6300 3181.78,-6300 3181.78,-6264 3285.78,-6264 3285.78,-6300"/>
+<text text-anchor="middle" x="3233.78" y="-6278.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 10&#45;&gt;106 -->
+<g id="edge78" class="edge">
+<title>10&#45;&gt;106</title>
+<path fill="none" stroke="black" d="M3317.91,-6336.05C3303.29,-6326.84 3285.03,-6315.32 3269.25,-6305.37"/>
+<polygon fill="black" stroke="black" points="3271.11,-6302.41 3260.79,-6300.03 3267.38,-6308.33 3271.11,-6302.41"/>
+</g>
+<!-- 225 -->
+<g id="node169" class="node">
+<title>225</title>
+<polygon fill="none" stroke="black" points="3647.78,-3060 3543.78,-3060 3543.78,-3024 3647.78,-3024 3647.78,-3060"/>
+<text text-anchor="middle" x="3595.78" y="-3038.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 10&#45;&gt;225 -->
+<g id="edge199" class="edge">
+<title>10&#45;&gt;225</title>
+<path fill="none" stroke="black" d="M3467.89,-6337.65C3541.87,-6320.76 3621.78,-6285.31 3621.78,-6211 3621.78,-6211 3621.78,-6211 3621.78,-3185 3621.78,-3144.55 3611.45,-3098.67 3603.81,-3070.28"/>
+<polygon fill="black" stroke="black" points="3607.09,-3069 3601.04,-3060.3 3600.34,-3070.87 3607.09,-3069"/>
+</g>
+<!-- 11 -->
+<g id="node12" class="node">
+<title>11</title>
+<ellipse fill="none" stroke="black" cx="2270.78" cy="-6282" rx="274.05" ry="18"/>
+<text text-anchor="middle" x="2270.78" y="-6278.3" font-family="Times,serif" font-size="14.00">attention.output.LayerNorm.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 107 -->
+<g id="node83" class="node">
+<title>107</title>
+<polygon fill="none" stroke="black" points="2373.28,-6228 2302.28,-6228 2302.28,-6192 2373.28,-6192 2373.28,-6228"/>
+<text text-anchor="middle" x="2337.78" y="-6206.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 11&#45;&gt;107 -->
+<g id="edge80" class="edge">
+<title>11&#45;&gt;107</title>
+<path fill="none" stroke="black" d="M2287.34,-6263.7C2295.46,-6255.22 2305.37,-6244.86 2314.25,-6235.58"/>
+<polygon fill="black" stroke="black" points="2317.02,-6237.75 2321.41,-6228.1 2311.97,-6232.91 2317.02,-6237.75"/>
+</g>
+<!-- 12 -->
+<g id="node13" class="node">
+<title>12</title>
+<ellipse fill="none" stroke="black" cx="2584.78" cy="-6354" rx="271.85" ry="18"/>
+<text text-anchor="middle" x="2584.78" y="-6350.3" font-family="Times,serif" font-size="14.00">intermediate.dense.weight: Tensor[(3072, 768), float32]</text>
+</g>
+<!-- 109 -->
+<g id="node85" class="node">
+<title>109</title>
+<polygon fill="none" stroke="black" points="2756.28,-6300 2563.28,-6300 2563.28,-6264 2756.28,-6264 2756.28,-6300"/>
+<text text-anchor="middle" x="2659.78" y="-6278.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 12&#45;&gt;109 -->
+<g id="edge82" class="edge">
+<title>12&#45;&gt;109</title>
+<path fill="none" stroke="black" d="M2603.32,-6335.7C2612.58,-6327.05 2623.94,-6316.45 2634.04,-6307.03"/>
+<polygon fill="black" stroke="black" points="2636.53,-6309.49 2641.45,-6300.1 2631.75,-6304.37 2636.53,-6309.49"/>
+</g>
+<!-- 13 -->
+<g id="node14" class="node">
+<title>13</title>
+<ellipse fill="none" stroke="black" cx="1857.78" cy="-5994" rx="238.56" ry="18"/>
+<text text-anchor="middle" x="1857.78" y="-5990.3" font-family="Times,serif" font-size="14.00">intermediate.dense.bias: Tensor[(3072,), float32]</text>
+</g>
+<!-- 116 -->
+<g id="node90" class="node">
+<title>116</title>
+<polygon fill="none" stroke="black" points="1893.28,-5940 1822.28,-5940 1822.28,-5904 1893.28,-5904 1893.28,-5940"/>
+<text text-anchor="middle" x="1857.78" y="-5918.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 13&#45;&gt;116 -->
+<g id="edge89" class="edge">
+<title>13&#45;&gt;116</title>
+<path fill="none" stroke="black" d="M1857.78,-5975.7C1857.78,-5967.98 1857.78,-5958.71 1857.78,-5950.11"/>
+<polygon fill="black" stroke="black" points="1861.28,-5950.1 1857.78,-5940.1 1854.28,-5950.1 1861.28,-5950.1"/>
+</g>
+<!-- 14 -->
+<g id="node15" class="node">
+<title>14</title>
+<ellipse fill="none" stroke="black" cx="2244.78" cy="-5706" rx="242.36" ry="18"/>
+<text text-anchor="middle" x="2244.78" y="-5702.3" font-family="Times,serif" font-size="14.00">output.dense.weight: Tensor[(768, 3072), float32]</text>
+</g>
+<!-- 128 -->
+<g id="node97" class="node">
+<title>128</title>
+<polygon fill="none" stroke="black" points="2341.28,-5652 2148.28,-5652 2148.28,-5616 2341.28,-5616 2341.28,-5652"/>
+<text text-anchor="middle" x="2244.78" y="-5630.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[1, 0])</text>
+</g>
+<!-- 14&#45;&gt;128 -->
+<g id="edge97" class="edge">
+<title>14&#45;&gt;128</title>
+<path fill="none" stroke="black" d="M2244.78,-5687.7C2244.78,-5679.98 2244.78,-5670.71 2244.78,-5662.11"/>
+<polygon fill="black" stroke="black" points="2248.28,-5662.1 2244.78,-5652.1 2241.28,-5662.1 2248.28,-5662.1"/>
+</g>
+<!-- 15 -->
+<g id="node16" class="node">
+<title>15</title>
+<ellipse fill="none" stroke="black" cx="1764.78" cy="-5346" rx="203.36" ry="18"/>
+<text text-anchor="middle" x="1764.78" y="-5342.3" font-family="Times,serif" font-size="14.00">output.dense.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 134 -->
+<g id="node102" class="node">
+<title>134</title>
+<polygon fill="none" stroke="black" points="1930.28,-5292 1859.28,-5292 1859.28,-5256 1930.28,-5256 1930.28,-5292"/>
+<text text-anchor="middle" x="1894.78" y="-5270.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 15&#45;&gt;134 -->
+<g id="edge104" class="edge">
+<title>15&#45;&gt;134</title>
+<path fill="none" stroke="black" d="M1796.25,-5328.05C1813.68,-5318.67 1835.55,-5306.89 1854.26,-5296.82"/>
+<polygon fill="black" stroke="black" points="1856,-5299.85 1863.15,-5292.03 1852.69,-5293.69 1856,-5299.85"/>
+</g>
+<!-- 16 -->
+<g id="node17" class="node">
+<title>16</title>
+<ellipse fill="none" stroke="black" cx="992.78" cy="-4842" rx="241.26" ry="18"/>
+<text text-anchor="middle" x="992.78" y="-4838.3" font-family="Times,serif" font-size="14.00">output.LayerNorm.weight: Tensor[(768,), float32]</text>
+</g>
+<!-- 146 -->
+<g id="node112" class="node">
+<title>146</title>
+<polygon fill="none" stroke="black" points="1044.78,-4716 940.78,-4716 940.78,-4680 1044.78,-4680 1044.78,-4716"/>
+<text text-anchor="middle" x="992.78" y="-4694.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 16&#45;&gt;146 -->
+<g id="edge120" class="edge">
+<title>16&#45;&gt;146</title>
+<path fill="none" stroke="black" d="M992.78,-4823.87C992.78,-4799.67 992.78,-4755.21 992.78,-4726.39"/>
+<polygon fill="black" stroke="black" points="996.28,-4726.19 992.78,-4716.19 989.28,-4726.19 996.28,-4726.19"/>
+</g>
+<!-- 170 -->
+<g id="node127" class="node">
+<title>170</title>
+<polygon fill="none" stroke="black" points="1467.78,-4788 1363.78,-4788 1363.78,-4752 1467.78,-4752 1467.78,-4788"/>
+<text text-anchor="middle" x="1415.78" y="-4766.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 16&#45;&gt;170 -->
+<g id="edge137" class="edge">
+<title>16&#45;&gt;170</title>
+<path fill="none" stroke="black" d="M1086.92,-4825.42C1168.25,-4811.96 1283.45,-4792.9 1353.7,-4781.27"/>
+<polygon fill="black" stroke="black" points="1354.44,-4784.7 1363.73,-4779.61 1353.29,-4777.79 1354.44,-4784.7"/>
+</g>
+<!-- 17 -->
+<g id="node18" class="node">
+<title>17</title>
+<ellipse fill="none" stroke="black" cx="228.78" cy="-4698" rx="228.56" ry="18"/>
+<text text-anchor="middle" x="228.78" y="-4694.3" font-family="Times,serif" font-size="14.00">output.LayerNorm.bias: Tensor[(768,), float32]</text>
+</g>
+<!-- 147 -->
+<g id="node113" class="node">
+<title>147</title>
+<polygon fill="none" stroke="black" points="457.28,-4644 386.28,-4644 386.28,-4608 457.28,-4608 457.28,-4644"/>
+<text text-anchor="middle" x="421.78" y="-4622.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 17&#45;&gt;147 -->
+<g id="edge122" class="edge">
+<title>17&#45;&gt;147</title>
+<path fill="none" stroke="black" d="M275.01,-4680.23C305.8,-4669.07 346.07,-4654.46 376.63,-4643.37"/>
+<polygon fill="black" stroke="black" points="377.99,-4646.61 386.19,-4639.91 375.6,-4640.03 377.99,-4646.61"/>
+</g>
+<!-- 18 -->
+<g id="node19" class="node">
+<title>18</title>
+<ellipse fill="none" stroke="black" cx="839.78" cy="-5130" rx="183.87" ry="18"/>
+<text text-anchor="middle" x="839.78" y="-5126.3" font-family="Times,serif" font-size="14.00">gr:out:0: Tensor[(1, 14, 768), float32]</text>
+</g>
+<!-- 166 -->
+<g id="node123" class="node">
+<title>166</title>
+<polygon fill="none" stroke="black" points="1015.78,-5076 911.78,-5076 911.78,-5040 1015.78,-5040 1015.78,-5076"/>
+<text text-anchor="middle" x="963.78" y="-5054.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 18&#45;&gt;166 -->
+<g id="edge131" class="edge">
+<title>18&#45;&gt;166</title>
+<path fill="none" stroke="black" d="M869.48,-5112.23C885.95,-5102.93 906.65,-5091.25 924.45,-5081.2"/>
+<polygon fill="black" stroke="black" points="926.47,-5084.08 933.46,-5076.12 923.03,-5077.98 926.47,-5084.08"/>
+</g>
+<!-- 19 -->
+<g id="node20" class="node">
+<title>19</title>
+<ellipse fill="none" stroke="black" cx="4790.78" cy="-7722" rx="204.16" ry="18"/>
+<text text-anchor="middle" x="4790.78" y="-7718.3" font-family="Times,serif" font-size="14.00">dropout:0: Tensor[(1, 12, 14, 14), float32]</text>
+</g>
+<!-- 65 -->
+<g id="node48" class="node">
+<title>65</title>
+<polygon fill="none" stroke="black" points="4878.78,-7668 4702.78,-7668 4702.78,-7632 4878.78,-7632 4878.78,-7668"/>
+<text text-anchor="middle" x="4790.78" y="-7646.3" font-family="Times,serif" font-size="14.00">multiply(·, 1.1111112)</text>
+</g>
+<!-- 19&#45;&gt;65 -->
+<g id="edge32" class="edge">
+<title>19&#45;&gt;65</title>
+<path fill="none" stroke="black" d="M4790.78,-7703.7C4790.78,-7695.98 4790.78,-7686.71 4790.78,-7678.11"/>
+<polygon fill="black" stroke="black" points="4794.28,-7678.1 4790.78,-7668.1 4787.28,-7678.1 4794.28,-7678.1"/>
+</g>
+<!-- 20 -->
+<g id="node21" class="node">
+<title>20</title>
+<ellipse fill="none" stroke="black" cx="3142.78" cy="-6930" rx="192.27" ry="18"/>
+<text text-anchor="middle" x="3142.78" y="-6926.3" font-family="Times,serif" font-size="14.00">dropout:1: Tensor[(1, 14, 768), float32]</text>
+</g>
+<!-- 93 -->
+<g id="node73" class="node">
+<title>93</title>
+<polygon fill="none" stroke="black" points="3230.78,-6876 3054.78,-6876 3054.78,-6840 3230.78,-6840 3230.78,-6876"/>
+<text text-anchor="middle" x="3142.78" y="-6854.3" font-family="Times,serif" font-size="14.00">multiply(·, 1.1111112)</text>
+</g>
+<!-- 20&#45;&gt;93 -->
+<g id="edge63" class="edge">
+<title>20&#45;&gt;93</title>
+<path fill="none" stroke="black" d="M3142.78,-6911.7C3142.78,-6903.98 3142.78,-6894.71 3142.78,-6886.11"/>
+<polygon fill="black" stroke="black" points="3146.28,-6886.1 3142.78,-6876.1 3139.28,-6886.1 3146.28,-6886.1"/>
+</g>
+<!-- 21 -->
+<g id="node22" class="node">
+<title>21</title>
+<ellipse fill="none" stroke="black" cx="1275.78" cy="-5346" rx="192.27" ry="18"/>
+<text text-anchor="middle" x="1275.78" y="-5342.3" font-family="Times,serif" font-size="14.00">dropout:2: Tensor[(1, 14, 768), float32]</text>
+</g>
+<!-- 136 -->
+<g id="node103" class="node">
+<title>136</title>
+<polygon fill="none" stroke="black" points="1363.78,-5292 1187.78,-5292 1187.78,-5256 1363.78,-5256 1363.78,-5292"/>
+<text text-anchor="middle" x="1275.78" y="-5270.3" font-family="Times,serif" font-size="14.00">multiply(·, 1.1111112)</text>
+</g>
+<!-- 21&#45;&gt;136 -->
+<g id="edge105" class="edge">
+<title>21&#45;&gt;136</title>
+<path fill="none" stroke="black" d="M1275.78,-5327.7C1275.78,-5319.98 1275.78,-5310.71 1275.78,-5302.11"/>
+<polygon fill="black" stroke="black" points="1279.28,-5302.1 1275.78,-5292.1 1272.28,-5302.1 1279.28,-5302.1"/>
+</g>
+<!-- 37 -->
+<g id="node27" class="node">
+<title>37</title>
+<polygon fill="none" stroke="black" points="4667.28,-8388 4498.28,-8388 4498.28,-8352 4667.28,-8352 4667.28,-8388"/>
+<text text-anchor="middle" x="4582.78" y="-8366.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 32&#45;&gt;37 -->
+<g id="edge5" class="edge">
+<title>32&#45;&gt;37</title>
+<path fill="none" stroke="black" d="M4223.11,-8495.87C4298.76,-8469.61 4443.11,-8419.49 4524.04,-8391.39"/>
+<polygon fill="black" stroke="black" points="4525.29,-8394.66 4533.59,-8388.08 4522.99,-8388.05 4525.29,-8394.66"/>
+</g>
+<!-- 49 -->
+<g id="node36" class="node">
+<title>49</title>
+<polygon fill="none" stroke="black" points="4991.28,-8460 4822.28,-8460 4822.28,-8424 4991.28,-8424 4991.28,-8460"/>
+<text text-anchor="middle" x="4906.78" y="-8438.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 32&#45;&gt;49 -->
+<g id="edge16" class="edge">
+<title>32&#45;&gt;49</title>
+<path fill="none" stroke="black" d="M4352.15,-8495.97C4496.95,-8482.14 4695.5,-8463.18 4811.56,-8452.09"/>
+<polygon fill="black" stroke="black" points="4812.22,-8455.55 4821.84,-8451.11 4811.56,-8448.58 4812.22,-8455.55"/>
+</g>
+<!-- 72 -->
+<g id="node54" class="node">
+<title>72</title>
+<polygon fill="none" stroke="black" points="4258.28,-7956 4089.28,-7956 4089.28,-7920 4258.28,-7920 4258.28,-7956"/>
+<text text-anchor="middle" x="4173.78" y="-7934.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 32&#45;&gt;72 -->
+<g id="edge39" class="edge">
+<title>32&#45;&gt;72</title>
+<path fill="none" stroke="black" d="M4173.78,-8495.95C4173.78,-8469.29 4173.78,-8416.11 4173.78,-8371 4173.78,-8371 4173.78,-8371 4173.78,-8081 4173.78,-8041 4173.78,-7994.65 4173.78,-7966.08"/>
+<polygon fill="black" stroke="black" points="4177.28,-7966.05 4173.78,-7956.05 4170.28,-7966.05 4177.28,-7966.05"/>
+</g>
+<!-- 323 -->
+<g id="node244" class="node">
+<title>323</title>
+<polygon fill="none" stroke="black" points="3444.28,-612 3233.28,-612 3233.28,-576 3444.28,-576 3444.28,-612"/>
+<text text-anchor="middle" x="3338.78" y="-590.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 32&#45;&gt;323 -->
+<g id="edge303" class="edge">
+<title>32&#45;&gt;323</title>
+<path fill="none" stroke="black" d="M3977.47,-8496C3748.46,-8473.45 3400.78,-8429.57 3400.78,-8371 3400.78,-8371 3400.78,-8371 3400.78,-7793 3400.78,-7365.61 3047.02,-7362.23 2941.78,-6948 2937.84,-6932.49 2941.01,-6927.98 2941.78,-6912 2946.01,-6824.14 2960.78,-6802.96 2960.78,-6715 2960.78,-6715 2960.78,-6715 2960.78,-6641 2960.78,-6315.99 3430.85,-6620.12 3640.78,-6372 3666.9,-6341.12 3659.78,-6323.45 3659.78,-6283 3659.78,-6283 3659.78,-6283 3659.78,-5201 3659.78,-4999.84 3404.73,-51 [...]
+<polygon fill="black" stroke="black" points="3322.86,-622.54 3325.63,-612.31 3317.09,-618.58 3322.86,-622.54"/>
+</g>
+<!-- 35 -->
+<g id="node25" class="node">
+<title>35</title>
+<polygon fill="none" stroke="black" points="5644.28,-8532 5183.28,-8532 5183.28,-8496 5644.28,-8496 5644.28,-8532"/>
+<text text-anchor="middle" x="5413.78" y="-8510.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 768 768]| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 33&#45;&gt;35 -->
+<g id="edge3" class="edge">
+<title>33&#45;&gt;35</title>
+<path fill="none" stroke="black" d="M5413.78,-8567.7C5413.78,-8559.98 5413.78,-8550.71 5413.78,-8542.11"/>
+<polygon fill="black" stroke="black" points="5417.28,-8542.1 5413.78,-8532.1 5410.28,-8542.1 5417.28,-8542.1"/>
+</g>
+<!-- 36 -->
+<g id="node26" class="node">
+<title>36</title>
+<polygon fill="none" stroke="black" points="5519.28,-8460 5308.28,-8460 5308.28,-8424 5519.28,-8424 5519.28,-8460"/>
+<text text-anchor="middle" x="5413.78" y="-8438.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 35&#45;&gt;36 -->
+<g id="edge4" class="edge">
+<title>35&#45;&gt;36</title>
+<path fill="none" stroke="black" d="M5413.78,-8495.7C5413.78,-8487.98 5413.78,-8478.71 5413.78,-8470.11"/>
+<polygon fill="black" stroke="black" points="5417.28,-8470.1 5413.78,-8460.1 5410.28,-8470.1 5417.28,-8470.1"/>
+</g>
+<!-- 36&#45;&gt;37 -->
+<g id="edge6" class="edge">
+<title>36&#45;&gt;37</title>
+<path fill="none" stroke="black" d="M5308.17,-8436.96C5166.86,-8430.76 4908.34,-8416.63 4677.78,-8388.04"/>
+<polygon fill="black" stroke="black" points="4677.88,-8384.53 4667.52,-8386.76 4677.01,-8391.47 4677.88,-8384.53"/>
+</g>
+<!-- 316 -->
+<g id="node238" class="node">
+<title>316</title>
+<polygon fill="none" stroke="black" points="6033.28,-6804 5822.28,-6804 5822.28,-6768 6033.28,-6768 6033.28,-6804"/>
+<text text-anchor="middle" x="5927.78" y="-6782.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 36&#45;&gt;316 -->
+<g id="edge294" class="edge">
+<title>36&#45;&gt;316</title>
+<path fill="none" stroke="black" d="M5519.54,-8429.44C5654.25,-8411.72 5869.78,-8371.97 5869.78,-8299 5869.78,-8299 5869.78,-8299 5869.78,-6929 5869.78,-6885.7 5893.18,-6840.11 5910.26,-6812.61"/>
+<polygon fill="black" stroke="black" points="5913.32,-6814.32 5915.76,-6804.01 5907.42,-6810.55 5913.32,-6814.32"/>
+</g>
+<!-- 39 -->
+<g id="node28" class="node">
+<title>39</title>
+<polygon fill="none" stroke="black" points="5128.28,-8316 4685.28,-8316 4685.28,-8280 5128.28,-8280 5128.28,-8316"/>
+<text text-anchor="middle" x="4906.78" y="-8294.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;1 &#160;14 768]| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 37&#45;&gt;39 -->
+<g id="edge7" class="edge">
+<title>37&#45;&gt;39</title>
+<path fill="none" stroke="black" d="M4661.62,-8351.97C4708.85,-8341.76 4769.02,-8328.76 4817.92,-8318.2"/>
+<polygon fill="black" stroke="black" points="4818.83,-8321.58 4827.86,-8316.05 4817.35,-8314.74 4818.83,-8321.58"/>
+</g>
+<!-- 39&#45;&gt;40 -->
+<g id="edge8" class="edge">
+<title>39&#45;&gt;40</title>
+<path fill="none" stroke="black" d="M4906.78,-8279.7C4906.78,-8271.98 4906.78,-8262.71 4906.78,-8254.11"/>
+<polygon fill="black" stroke="black" points="4910.28,-8254.1 4906.78,-8244.1 4903.28,-8254.1 4910.28,-8254.1"/>
+</g>
+<!-- 42 -->
+<g id="node30" class="node">
+<title>42</title>
+<polygon fill="none" stroke="black" points="5101.28,-8172 4722.28,-8172 4722.28,-8136 5101.28,-8136 5101.28,-8172"/>
+<text text-anchor="middle" x="4911.78" y="-8150.3" font-family="Times,serif" font-size="14.00">reshape(·, [ 1 14 12 64]| newshape=..., reverse=0)</text>
+</g>
+<!-- 40&#45;&gt;42 -->
+<g id="edge10" class="edge">
+<title>40&#45;&gt;42</title>
+<path fill="none" stroke="black" d="M4908.02,-8207.7C4908.57,-8199.98 4909.23,-8190.71 4909.84,-8182.11"/>
+<polygon fill="black" stroke="black" points="4913.34,-8182.33 4910.56,-8172.1 4906.35,-8181.83 4913.34,-8182.33"/>
+</g>
+<!-- 43 -->
+<g id="node31" class="node">
+<title>43</title>
+<polygon fill="none" stroke="black" points="5029.28,-8100 4800.28,-8100 4800.28,-8064 5029.28,-8064 5029.28,-8100"/>
+<text text-anchor="middle" x="4914.78" y="-8078.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 42&#45;&gt;43 -->
+<g id="edge11" class="edge">
+<title>42&#45;&gt;43</title>
+<path fill="none" stroke="black" d="M4912.52,-8135.7C4912.85,-8127.98 4913.25,-8118.71 4913.62,-8110.11"/>
+<polygon fill="black" stroke="black" points="4917.12,-8110.25 4914.05,-8100.1 4910.12,-8109.95 4917.12,-8110.25"/>
+</g>
+<!-- 45 -->
+<g id="node32" class="node">
+<title>45</title>
+<polygon fill="none" stroke="black" points="5128.28,-8028 4707.28,-8028 4707.28,-7992 5128.28,-7992 5128.28,-8028"/>
+<text text-anchor="middle" x="4917.78" y="-8006.3" font-family="Times,serif" font-size="14.00">reshape(·, [&#45;1 14 64]| newshape=[&#45;1, 14, 64], reverse=0)</text>
+</g>
+<!-- 43&#45;&gt;45 -->
+<g id="edge12" class="edge">
+<title>43&#45;&gt;45</title>
+<path fill="none" stroke="black" d="M4915.52,-8063.7C4915.85,-8055.98 4916.25,-8046.71 4916.62,-8038.11"/>
+<polygon fill="black" stroke="black" points="4920.12,-8038.25 4917.05,-8028.1 4913.12,-8037.95 4920.12,-8038.25"/>
+</g>
+<!-- 57 -->
+<g id="node43" class="node">
+<title>57</title>
+<polygon fill="none" stroke="black" points="5529.28,-7956 5360.28,-7956 5360.28,-7920 5529.28,-7920 5529.28,-7956"/>
+<text text-anchor="middle" x="5444.78" y="-7934.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 45&#45;&gt;57 -->
+<g id="edge25" class="edge">
+<title>45&#45;&gt;57</title>
+<path fill="none" stroke="black" d="M5046.02,-7991.97C5139.86,-7979.5 5265.09,-7962.87 5349.99,-7951.59"/>
+<polygon fill="black" stroke="black" points="5350.69,-7955.03 5360.14,-7950.24 5349.77,-7948.09 5350.69,-7955.03"/>
+</g>
+<!-- 297 -->
+<g id="node223" class="node">
+<title>297</title>
+<polygon fill="none" stroke="black" points="5102.28,-6732 4891.28,-6732 4891.28,-6696 5102.28,-6696 5102.28,-6732"/>
+<text text-anchor="middle" x="4996.78" y="-6710.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 45&#45;&gt;297 -->
+<g id="edge274" class="edge">
+<title>45&#45;&gt;297</title>
+<path fill="none" stroke="black" d="M4942.13,-7991.87C4973.15,-7967.84 5022.78,-7920.96 5022.78,-7867 5022.78,-7867 5022.78,-7867 5022.78,-6857 5022.78,-6816.55 5012.45,-6770.67 5004.81,-6742.28"/>
+<polygon fill="black" stroke="black" points="5008.09,-6741 5002.04,-6732.3 5001.34,-6742.87 5008.09,-6741"/>
+</g>
+<!-- 47 -->
+<g id="node34" class="node">
+<title>47</title>
+<polygon fill="none" stroke="black" points="5201.28,-8604 4740.28,-8604 4740.28,-8568 5201.28,-8568 5201.28,-8604"/>
+<text text-anchor="middle" x="4970.78" y="-8582.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 768 768]| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 46&#45;&gt;47 -->
+<g id="edge14" class="edge">
+<title>46&#45;&gt;47</title>
+<path fill="none" stroke="black" d="M4970.78,-8639.7C4970.78,-8631.98 4970.78,-8622.71 4970.78,-8614.11"/>
+<polygon fill="black" stroke="black" points="4974.28,-8614.1 4970.78,-8604.1 4967.28,-8614.1 4974.28,-8614.1"/>
+</g>
+<!-- 48 -->
+<g id="node35" class="node">
+<title>48</title>
+<polygon fill="none" stroke="black" points="5076.28,-8532 4865.28,-8532 4865.28,-8496 5076.28,-8496 5076.28,-8532"/>
+<text text-anchor="middle" x="4970.78" y="-8510.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 47&#45;&gt;48 -->
+<g id="edge15" class="edge">
+<title>47&#45;&gt;48</title>
+<path fill="none" stroke="black" d="M4970.78,-8567.7C4970.78,-8559.98 4970.78,-8550.71 4970.78,-8542.11"/>
+<polygon fill="black" stroke="black" points="4974.28,-8542.1 4970.78,-8532.1 4967.28,-8542.1 4974.28,-8542.1"/>
+</g>
+<!-- 48&#45;&gt;49 -->
+<g id="edge17" class="edge">
+<title>48&#45;&gt;49</title>
+<path fill="none" stroke="black" d="M4954.96,-8495.7C4947.21,-8487.22 4937.74,-8476.86 4929.26,-8467.58"/>
+<polygon fill="black" stroke="black" points="4931.75,-8465.12 4922.42,-8460.1 4926.58,-8469.85 4931.75,-8465.12"/>
+</g>
+<!-- 305 -->
+<g id="node229" class="node">
+<title>305</title>
+<polygon fill="none" stroke="black" points="5262.28,-7380 5051.28,-7380 5051.28,-7344 5262.28,-7344 5262.28,-7380"/>
+<text text-anchor="middle" x="5156.78" y="-7358.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 48&#45;&gt;305 -->
+<g id="edge281" class="edge">
+<title>48&#45;&gt;305</title>
+<path fill="none" stroke="black" d="M5028.36,-8495.96C5082.98,-8475.83 5156.78,-8436.51 5156.78,-8371 5156.78,-8371 5156.78,-8371 5156.78,-7505 5156.78,-7465 5156.78,-7418.65 5156.78,-7390.08"/>
+<polygon fill="black" stroke="black" points="5160.28,-7390.05 5156.78,-7380.05 5153.28,-7390.05 5160.28,-7390.05"/>
+</g>
+<!-- 50 -->
+<g id="node37" class="node">
+<title>50</title>
+<polygon fill="none" stroke="black" points="5128.28,-8388 4685.28,-8388 4685.28,-8352 5128.28,-8352 5128.28,-8388"/>
+<text text-anchor="middle" x="4906.78" y="-8366.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;1 &#160;14 768]| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 49&#45;&gt;50 -->
+<g id="edge18" class="edge">
+<title>49&#45;&gt;50</title>
+<path fill="none" stroke="black" d="M4906.78,-8423.7C4906.78,-8415.98 4906.78,-8406.71 4906.78,-8398.11"/>
+<polygon fill="black" stroke="black" points="4910.28,-8398.1 4906.78,-8388.1 4903.28,-8398.1 4910.28,-8398.1"/>
+</g>
+<!-- 50&#45;&gt;51 -->
+<g id="edge19" class="edge">
+<title>50&#45;&gt;51</title>
+<path fill="none" stroke="black" d="M5029.18,-8351.97C5137.22,-8336.93 5288.87,-8315.83 5363.95,-8305.38"/>
+<polygon fill="black" stroke="black" points="5364.6,-8308.82 5374.02,-8303.98 5363.64,-8301.89 5364.6,-8308.82"/>
+</g>
+<!-- 52 -->
+<g id="node39" class="node">
+<title>52</title>
+<polygon fill="none" stroke="black" points="5601.28,-8244 5222.28,-8244 5222.28,-8208 5601.28,-8208 5601.28,-8244"/>
+<text text-anchor="middle" x="5411.78" y="-8222.3" font-family="Times,serif" font-size="14.00">reshape(·, [ 1 14 12 64]| newshape=..., reverse=0)</text>
+</g>
+<!-- 51&#45;&gt;52 -->
+<g id="edge21" class="edge">
+<title>51&#45;&gt;52</title>
+<path fill="none" stroke="black" d="M5410.27,-8279.7C5410.49,-8271.98 5410.76,-8262.71 5411.01,-8254.11"/>
+<polygon fill="black" stroke="black" points="5414.5,-8254.2 5411.29,-8244.1 5407.51,-8254 5414.5,-8254.2"/>
+</g>
+<!-- 53 -->
+<g id="node40" class="node">
+<title>53</title>
+<polygon fill="none" stroke="black" points="5531.28,-8172 5302.28,-8172 5302.28,-8136 5531.28,-8136 5531.28,-8172"/>
+<text text-anchor="middle" x="5416.78" y="-8150.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 3, 1])</text>
+</g>
+<!-- 52&#45;&gt;53 -->
+<g id="edge22" class="edge">
+<title>52&#45;&gt;53</title>
+<path fill="none" stroke="black" d="M5413.02,-8207.7C5413.57,-8199.98 5414.23,-8190.71 5414.84,-8182.11"/>
+<polygon fill="black" stroke="black" points="5418.34,-8182.33 5415.56,-8172.1 5411.35,-8181.83 5418.34,-8182.33"/>
+</g>
+<!-- 55 -->
+<g id="node41" class="node">
+<title>55</title>
+<polygon fill="none" stroke="black" points="5646.28,-8100 5225.28,-8100 5225.28,-8064 5646.28,-8064 5646.28,-8100"/>
+<text text-anchor="middle" x="5435.78" y="-8078.3" font-family="Times,serif" font-size="14.00">reshape(·, [&#45;1 64 14]| newshape=[&#45;1, 64, 14], reverse=0)</text>
+</g>
+<!-- 53&#45;&gt;55 -->
+<g id="edge23" class="edge">
+<title>53&#45;&gt;55</title>
+<path fill="none" stroke="black" d="M5421.48,-8135.7C5423.59,-8127.9 5426.14,-8118.51 5428.5,-8109.83"/>
+<polygon fill="black" stroke="black" points="5431.9,-8110.67 5431.14,-8100.1 5425.14,-8108.84 5431.9,-8110.67"/>
+</g>
+<!-- 56 -->
+<g id="node42" class="node">
+<title>56</title>
+<polygon fill="none" stroke="black" points="5550.28,-8028 5339.28,-8028 5339.28,-7992 5550.28,-7992 5550.28,-8028"/>
+<text text-anchor="middle" x="5444.78" y="-8006.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 55&#45;&gt;56 -->
+<g id="edge24" class="edge">
+<title>55&#45;&gt;56</title>
+<path fill="none" stroke="black" d="M5438,-8063.7C5439,-8055.98 5440.19,-8046.71 5441.29,-8038.11"/>
+<polygon fill="black" stroke="black" points="5444.78,-8038.47 5442.58,-8028.1 5437.83,-8037.58 5444.78,-8038.47"/>
+</g>
+<!-- 56&#45;&gt;57 -->
+<g id="edge26" class="edge">
+<title>56&#45;&gt;57</title>
+<path fill="none" stroke="black" d="M5444.78,-7991.7C5444.78,-7983.98 5444.78,-7974.71 5444.78,-7966.11"/>
+<polygon fill="black" stroke="black" points="5448.28,-7966.1 5444.78,-7956.1 5441.28,-7966.1 5448.28,-7966.1"/>
+</g>
+<!-- 309 -->
+<g id="node233" class="node">
+<title>309</title>
+<polygon fill="none" stroke="black" points="5753.28,-6516 5542.28,-6516 5542.28,-6480 5753.28,-6480 5753.28,-6516"/>
+<text text-anchor="middle" x="5647.78" y="-6494.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 56&#45;&gt;309 -->
+<g id="edge288" class="edge">
+<title>56&#45;&gt;309</title>
+<path fill="none" stroke="black" d="M5550.69,-8005.89C5646.87,-7996.74 5773.78,-7965.92 5773.78,-7867 5773.78,-7867 5773.78,-7867 5773.78,-6641 5773.78,-6588.21 5724.46,-6545.9 5687.54,-6521.61"/>
+<polygon fill="black" stroke="black" points="5689.24,-6518.54 5678.93,-6516.12 5685.48,-6524.45 5689.24,-6518.54"/>
+</g>
+<!-- 59 -->
+<g id="node44" class="node">
+<title>59</title>
+<polygon fill="none" stroke="black" points="5671.28,-7884 5292.28,-7884 5292.28,-7848 5671.28,-7848 5671.28,-7884"/>
+<text text-anchor="middle" x="5481.78" y="-7862.3" font-family="Times,serif" font-size="14.00">reshape(·, [ 1 12 14 14]| newshape=..., reverse=0)</text>
+</g>
+<!-- 57&#45;&gt;59 -->
+<g id="edge27" class="edge">
+<title>57&#45;&gt;59</title>
+<path fill="none" stroke="black" d="M5453.93,-7919.7C5458.18,-7911.64 5463.34,-7901.89 5468.05,-7892.98"/>
+<polygon fill="black" stroke="black" points="5471.16,-7894.58 5472.74,-7884.1 5464.97,-7891.31 5471.16,-7894.58"/>
+</g>
+<!-- 61 -->
+<g id="node45" class="node">
+<title>61</title>
+<polygon fill="none" stroke="black" points="5745.28,-7812 5638.28,-7812 5638.28,-7776 5745.28,-7776 5745.28,-7812"/>
+<text text-anchor="middle" x="5691.78" y="-7790.3" font-family="Times,serif" font-size="14.00">divide(·, 8.0)</text>
+</g>
+<!-- 59&#45;&gt;61 -->
+<g id="edge28" class="edge">
+<title>59&#45;&gt;61</title>
+<path fill="none" stroke="black" d="M5532.88,-7847.97C5562.35,-7838.14 5599.59,-7825.73 5630.61,-7815.39"/>
+<polygon fill="black" stroke="black" points="5632.07,-7818.59 5640.45,-7812.11 5629.85,-7811.95 5632.07,-7818.59"/>
+</g>
+<!-- 61&#45;&gt;62 -->
+<g id="edge29" class="edge">
+<title>61&#45;&gt;62</title>
+<path fill="none" stroke="black" d="M5638.04,-7778.13C5635.25,-7777.4 5632.48,-7776.69 5629.78,-7776 5567.24,-7760.13 5494.55,-7743.47 5448.61,-7733.16"/>
+<polygon fill="black" stroke="black" points="5449.18,-7729.7 5438.66,-7730.93 5447.65,-7736.53 5449.18,-7729.7"/>
+</g>
+<!-- 63 -->
+<g id="node47" class="node">
+<title>63</title>
+<polygon fill="none" stroke="black" points="5456.78,-7668 5282.78,-7668 5282.78,-7632 5456.78,-7632 5456.78,-7668"/>
+<text text-anchor="middle" x="5369.78" y="-7646.3" font-family="Times,serif" font-size="14.00">nn.softmax(·| axis=&#45;1)</text>
+</g>
+<!-- 62&#45;&gt;63 -->
+<g id="edge31" class="edge">
+<title>62&#45;&gt;63</title>
+<path fill="none" stroke="black" d="M5394.62,-7703.7C5390.87,-7695.73 5386.33,-7686.1 5382.16,-7677.26"/>
+<polygon fill="black" stroke="black" points="5385.27,-7675.66 5377.84,-7668.1 5378.94,-7678.64 5385.27,-7675.66"/>
+</g>
+<!-- 66 -->
+<g id="node49" class="node">
+<title>66</title>
+<polygon fill="none" stroke="black" points="4834.78,-7596 4730.78,-7596 4730.78,-7560 4834.78,-7560 4834.78,-7596"/>
+<text text-anchor="middle" x="4782.78" y="-7574.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 63&#45;&gt;66 -->
+<g id="edge33" class="edge">
+<title>63&#45;&gt;66</title>
+<path fill="none" stroke="black" d="M5282.53,-7638.6C5162.76,-7624.31 4950.34,-7598.98 4845.24,-7586.45"/>
+<polygon fill="black" stroke="black" points="4845.45,-7582.95 4835.11,-7585.24 4844.62,-7589.9 4845.45,-7582.95"/>
+</g>
+<!-- 289 -->
+<g id="node216" class="node">
+<title>289</title>
+<polygon fill="none" stroke="black" points="5301.78,-1404 5197.78,-1404 5197.78,-1368 5301.78,-1368 5301.78,-1404"/>
+<text text-anchor="middle" x="5249.78" y="-1382.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 63&#45;&gt;289 -->
+<g id="edge265" class="edge">
+<title>63&#45;&gt;289</title>
+<path fill="none" stroke="black" d="M5367.21,-7632C5363.42,-7605.4 5356.78,-7552.3 5356.78,-7507 5356.78,-7507 5356.78,-7507 5356.78,-1529 5356.78,-1479.13 5314.16,-1435.65 5282.72,-1410.35"/>
+<polygon fill="black" stroke="black" points="5284.63,-1407.41 5274.6,-1404.02 5280.33,-1412.93 5284.63,-1407.41"/>
+</g>
+<!-- 292 -->
+<g id="node219" class="node">
+<title>292</title>
+<polygon fill="none" stroke="black" points="5441.78,-1188 5337.78,-1188 5337.78,-1152 5441.78,-1152 5441.78,-1188"/>
+<text text-anchor="middle" x="5389.78" y="-1166.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 63&#45;&gt;292 -->
+<g id="edge270" class="edge">
+<title>63&#45;&gt;292</title>
+<path fill="none" stroke="black" d="M5383.99,-7631.86C5403.66,-7606.34 5436.78,-7556.07 5436.78,-7507 5436.78,-7507 5436.78,-7507 5436.78,-1313 5436.78,-1270.96 5417.98,-1225.29 5404.16,-1197.41"/>
+<polygon fill="black" stroke="black" points="5407.18,-1195.63 5399.52,-1188.32 5400.95,-1198.82 5407.18,-1195.63"/>
+</g>
+<!-- 65&#45;&gt;66 -->
+<g id="edge34" class="edge">
+<title>65&#45;&gt;66</title>
+<path fill="none" stroke="black" d="M4788.8,-7631.7C4787.92,-7623.98 4786.86,-7614.71 4785.88,-7606.11"/>
+<polygon fill="black" stroke="black" points="4789.35,-7605.64 4784.73,-7596.1 4782.39,-7606.44 4789.35,-7605.64"/>
+</g>
+<!-- 287 -->
+<g id="node215" class="node">
+<title>287</title>
+<polygon fill="none" stroke="black" points="4872.78,-1476 4768.78,-1476 4768.78,-1440 4872.78,-1440 4872.78,-1476"/>
+<text text-anchor="middle" x="4820.78" y="-1454.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 65&#45;&gt;287 -->
+<g id="edge263" class="edge">
+<title>65&#45;&gt;287</title>
+<path fill="none" stroke="black" d="M4813.18,-7631.87C4824.01,-7622.43 4836.26,-7609.83 4843.78,-7596 4863.1,-7560.46 4862.78,-7547.45 4862.78,-7507 4862.78,-7507 4862.78,-7507 4862.78,-1601 4862.78,-1560.64 4857.18,-1550.4 4844.78,-1512 4841.92,-1503.16 4838.02,-1493.82 4834.19,-1485.5"/>
+<polygon fill="black" stroke="black" points="4837.26,-1483.8 4829.8,-1476.27 4830.94,-1486.81 4837.26,-1483.8"/>
+</g>
+<!-- 68 -->
+<g id="node50" class="node">
+<title>68</title>
+<polygon fill="none" stroke="black" points="4266.28,-7524 3845.28,-7524 3845.28,-7488 4266.28,-7488 4266.28,-7524"/>
+<text text-anchor="middle" x="4055.78" y="-7502.3" font-family="Times,serif" font-size="14.00">reshape(·, [&#45;1 14 14]| newshape=[&#45;1, 14, 14], reverse=0)</text>
+</g>
+<!-- 66&#45;&gt;68 -->
+<g id="edge35" class="edge">
+<title>66&#45;&gt;68</title>
+<path fill="none" stroke="black" d="M4730.55,-7570.38C4703.74,-7567.08 4670.53,-7563.13 4640.78,-7560 4518.09,-7547.1 4380.72,-7534.56 4269.94,-7524.91"/>
+<polygon fill="black" stroke="black" points="4269.91,-7521.4 4259.65,-7524.02 4269.31,-7528.37 4269.91,-7521.4"/>
+</g>
+<!-- 79 -->
+<g id="node61" class="node">
+<title>79</title>
+<polygon fill="none" stroke="black" points="4338.28,-7452 4169.28,-7452 4169.28,-7416 4338.28,-7416 4338.28,-7452"/>
+<text text-anchor="middle" x="4253.78" y="-7430.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 68&#45;&gt;79 -->
+<g id="edge48" class="edge">
+<title>68&#45;&gt;79</title>
+<path fill="none" stroke="black" d="M4103.96,-7487.97C4131.63,-7478.19 4166.56,-7465.84 4195.73,-7455.52"/>
+<polygon fill="black" stroke="black" points="4197.12,-7458.74 4205.38,-7452.11 4194.79,-7452.14 4197.12,-7458.74"/>
+</g>
+<!-- 271 -->
+<g id="node202" class="node">
+<title>271</title>
+<polygon fill="none" stroke="black" points="4432.28,-5940 4221.28,-5940 4221.28,-5904 4432.28,-5904 4432.28,-5940"/>
+<text text-anchor="middle" x="4326.78" y="-5918.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 68&#45;&gt;271 -->
+<g id="edge244" class="edge">
+<title>68&#45;&gt;271</title>
+<path fill="none" stroke="black" d="M4052.92,-7487.97C4047,-7448.98 4035.52,-7350.53 4055.78,-7272 4073.91,-7201.73 4086.38,-7183.81 4132.78,-7128 4214.65,-7029.52 4289.91,-7056.6 4357.78,-6948 4379.27,-6913.62 4377.78,-6899.54 4377.78,-6859 4377.78,-6859 4377.78,-6859 4377.78,-6065 4377.78,-6022.59 4357.38,-5977.02 4342.38,-5949.26"/>
+<polygon fill="black" stroke="black" points="4345.27,-5947.25 4337.35,-5940.21 4339.15,-5950.65 4345.27,-5947.25"/>
+</g>
+<!-- 70 -->
+<g id="node52" class="node">
+<title>70</title>
+<polygon fill="none" stroke="black" points="4136.28,-8100 3675.28,-8100 3675.28,-8064 4136.28,-8064 4136.28,-8100"/>
+<text text-anchor="middle" x="3905.78" y="-8078.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 768 768]| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 69&#45;&gt;70 -->
+<g id="edge37" class="edge">
+<title>69&#45;&gt;70</title>
+<path fill="none" stroke="black" d="M3892.98,-8135.7C3894.88,-8127.9 3897.16,-8118.51 3899.26,-8109.83"/>
+<polygon fill="black" stroke="black" points="3902.67,-8110.65 3901.63,-8100.1 3895.86,-8109 3902.67,-8110.65"/>
+</g>
+<!-- 71 -->
+<g id="node53" class="node">
+<title>71</title>
+<polygon fill="none" stroke="black" points="4145.28,-8028 3934.28,-8028 3934.28,-7992 4145.28,-7992 4145.28,-8028"/>
+<text text-anchor="middle" x="4039.78" y="-8006.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 70&#45;&gt;71 -->
+<g id="edge38" class="edge">
+<title>70&#45;&gt;71</title>
+<path fill="none" stroke="black" d="M3938.56,-8063.88C3956.4,-8054.55 3978.68,-8042.92 3997.78,-8032.94"/>
+<polygon fill="black" stroke="black" points="3999.63,-8035.92 4006.87,-8028.19 3996.39,-8029.72 3999.63,-8035.92"/>
+</g>
+<!-- 71&#45;&gt;72 -->
+<g id="edge40" class="edge">
+<title>71&#45;&gt;72</title>
+<path fill="none" stroke="black" d="M4072.56,-7991.88C4090.4,-7982.55 4112.68,-7970.92 4131.78,-7960.94"/>
+<polygon fill="black" stroke="black" points="4133.63,-7963.92 4140.87,-7956.19 4130.39,-7957.72 4133.63,-7963.92"/>
+</g>
+<!-- 279 -->
+<g id="node208" class="node">
+<title>279</title>
+<polygon fill="none" stroke="black" points="4700.28,-6660 4489.28,-6660 4489.28,-6624 4700.28,-6624 4700.28,-6660"/>
+<text text-anchor="middle" x="4594.78" y="-6638.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 71&#45;&gt;279 -->
+<g id="edge251" class="edge">
+<title>71&#45;&gt;279</title>
+<path fill="none" stroke="black" d="M4040.06,-7991.7C4040.96,-7961.79 4044.67,-7898.75 4059.78,-7848 4100.46,-7711.41 4099.88,-7660.65 4200.78,-7560 4226.67,-7534.17 4244.58,-7544.62 4274.78,-7524 4294.26,-7510.7 4294.67,-7501.83 4313.78,-7488 4447.25,-7391.45 4539.8,-7440.42 4637.78,-7308 4661.84,-7275.49 4656.78,-7259.45 4656.78,-7219 4656.78,-7219 4656.78,-7219 4656.78,-6785 4656.78,-6741.5 4632.04,-6696.29 4613.81,-6668.89"/>
+<polygon fill="black" stroke="black" points="4616.47,-6666.58 4607.93,-6660.31 4610.7,-6670.54 4616.47,-6666.58"/>
+</g>
+<!-- 73 -->
+<g id="node55" class="node">
+<title>73</title>
+<polygon fill="none" stroke="black" points="4512.28,-7884 4069.28,-7884 4069.28,-7848 4512.28,-7848 4512.28,-7884"/>
+<text text-anchor="middle" x="4290.78" y="-7862.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;1 &#160;14 768]| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 72&#45;&gt;73 -->
+<g id="edge41" class="edge">
+<title>72&#45;&gt;73</title>
+<path fill="none" stroke="black" d="M4202.4,-7919.88C4217.7,-7910.72 4236.72,-7899.34 4253.2,-7889.48"/>
+<polygon fill="black" stroke="black" points="4255.26,-7892.33 4262.05,-7884.19 4251.67,-7886.32 4255.26,-7892.33"/>
+</g>
+<!-- 73&#45;&gt;74 -->
+<g id="edge42" class="edge">
+<title>73&#45;&gt;74</title>
+<path fill="none" stroke="black" d="M4312.53,-7847.7C4323.62,-7838.88 4337.25,-7828.03 4349.27,-7818.47"/>
+<polygon fill="black" stroke="black" points="4351.63,-7821.07 4357.28,-7812.1 4347.27,-7815.59 4351.63,-7821.07"/>
+</g>
+<!-- 75 -->
+<g id="node57" class="node">
+<title>75</title>
+<polygon fill="none" stroke="black" points="4568.28,-7740 4189.28,-7740 4189.28,-7704 4568.28,-7704 4568.28,-7740"/>
+<text text-anchor="middle" x="4378.78" y="-7718.3" font-family="Times,serif" font-size="14.00">reshape(·, [ 1 14 12 64]| newshape=..., reverse=0)</text>
+</g>
+<!-- 74&#45;&gt;75 -->
+<g id="edge44" class="edge">
+<title>74&#45;&gt;75</title>
+<path fill="none" stroke="black" d="M4378.78,-7775.7C4378.78,-7767.98 4378.78,-7758.71 4378.78,-7750.11"/>
+<polygon fill="black" stroke="black" points="4382.28,-7750.1 4378.78,-7740.1 4375.28,-7750.1 4382.28,-7750.1"/>
+</g>
+<!-- 76 -->
+<g id="node58" class="node">
+<title>76</title>
+<polygon fill="none" stroke="black" points="4508.28,-7668 4279.28,-7668 4279.28,-7632 4508.28,-7632 4508.28,-7668"/>
+<text text-anchor="middle" x="4393.78" y="-7646.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 75&#45;&gt;76 -->
+<g id="edge45" class="edge">
+<title>75&#45;&gt;76</title>
+<path fill="none" stroke="black" d="M4382.49,-7703.7C4384.14,-7695.98 4386.13,-7686.71 4387.97,-7678.11"/>
+<polygon fill="black" stroke="black" points="4391.44,-7678.62 4390.11,-7668.1 4384.6,-7677.15 4391.44,-7678.62"/>
+</g>
+<!-- 77 -->
+<g id="node59" class="node">
+<title>77</title>
+<polygon fill="none" stroke="black" points="4631.28,-7596 4210.28,-7596 4210.28,-7560 4631.28,-7560 4631.28,-7596"/>
+<text text-anchor="middle" x="4420.78" y="-7574.3" font-family="Times,serif" font-size="14.00">reshape(·, [&#45;1 14 64]| newshape=[&#45;1, 14, 64], reverse=0)</text>
+</g>
+<!-- 76&#45;&gt;77 -->
+<g id="edge46" class="edge">
+<title>76&#45;&gt;77</title>
+<path fill="none" stroke="black" d="M4400.45,-7631.7C4403.49,-7623.81 4407.16,-7614.3 4410.54,-7605.55"/>
+<polygon fill="black" stroke="black" points="4413.85,-7606.69 4414.18,-7596.1 4407.32,-7604.17 4413.85,-7606.69"/>
+</g>
+<!-- 78 -->
+<g id="node60" class="node">
+<title>78</title>
+<polygon fill="none" stroke="black" points="4534.28,-7524 4323.28,-7524 4323.28,-7488 4534.28,-7488 4534.28,-7524"/>
+<text text-anchor="middle" x="4428.78" y="-7502.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 77&#45;&gt;78 -->
+<g id="edge47" class="edge">
+<title>77&#45;&gt;78</title>
+<path fill="none" stroke="black" d="M4422.76,-7559.7C4423.64,-7551.98 4424.7,-7542.71 4425.68,-7534.11"/>
+<polygon fill="black" stroke="black" points="4429.17,-7534.44 4426.82,-7524.1 4422.21,-7533.64 4429.17,-7534.44"/>
+</g>
+<!-- 78&#45;&gt;79 -->
+<g id="edge49" class="edge">
+<title>78&#45;&gt;79</title>
+<path fill="none" stroke="black" d="M4385.97,-7487.88C4361.83,-7478.22 4331.5,-7466.09 4305.97,-7455.88"/>
+<polygon fill="black" stroke="black" points="4307.04,-7452.53 4296.45,-7452.07 4304.44,-7459.03 4307.04,-7452.53"/>
+</g>
+<!-- 283 -->
+<g id="node212" class="node">
+<title>283</title>
+<polygon fill="none" stroke="black" points="4834.28,-6444 4623.28,-6444 4623.28,-6408 4834.28,-6408 4834.28,-6444"/>
+<text text-anchor="middle" x="4728.78" y="-6422.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 78&#45;&gt;283 -->
+<g id="edge258" class="edge">
+<title>78&#45;&gt;283</title>
+<path fill="none" stroke="black" d="M4534.55,-7497.48C4620.92,-7484.88 4728.78,-7451.45 4728.78,-7363 4728.78,-7363 4728.78,-7363 4728.78,-6569 4728.78,-6529 4728.78,-6482.65 4728.78,-6454.08"/>
+<polygon fill="black" stroke="black" points="4732.28,-6454.05 4728.78,-6444.05 4725.28,-6454.05 4732.28,-6454.05"/>
+</g>
+<!-- 81 -->
+<g id="node62" class="node">
+<title>81</title>
+<polygon fill="none" stroke="black" points="4443.28,-7380 4064.28,-7380 4064.28,-7344 4443.28,-7344 4443.28,-7380"/>
+<text text-anchor="middle" x="4253.78" y="-7358.3" font-family="Times,serif" font-size="14.00">reshape(·, [ 1 12 14 64]| newshape=..., reverse=0)</text>
+</g>
+<!-- 79&#45;&gt;81 -->
+<g id="edge50" class="edge">
+<title>79&#45;&gt;81</title>
+<path fill="none" stroke="black" d="M4253.78,-7415.7C4253.78,-7407.98 4253.78,-7398.71 4253.78,-7390.11"/>
+<polygon fill="black" stroke="black" points="4257.28,-7390.1 4253.78,-7380.1 4250.28,-7390.1 4257.28,-7390.1"/>
+</g>
+<!-- 82 -->
+<g id="node63" class="node">
+<title>82</title>
+<polygon fill="none" stroke="black" points="3993.28,-7308 3764.28,-7308 3764.28,-7272 3993.28,-7272 3993.28,-7308"/>
+<text text-anchor="middle" x="3878.78" y="-7286.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1, 3])</text>
+</g>
+<!-- 81&#45;&gt;82 -->
+<g id="edge51" class="edge">
+<title>81&#45;&gt;82</title>
+<path fill="none" stroke="black" d="M4162.52,-7343.97C4107.41,-7333.68 4037.06,-7320.55 3980.21,-7309.93"/>
+<polygon fill="black" stroke="black" points="3980.59,-7306.44 3970.12,-7308.05 3979.31,-7313.33 3980.59,-7306.44"/>
+</g>
+<!-- 83 -->
+<g id="node64" class="node">
+<title>83</title>
+<polygon fill="none" stroke="black" points="3903.28,-7236 3838.28,-7236 3838.28,-7200 3903.28,-7200 3903.28,-7236"/>
+<text text-anchor="middle" x="3870.78" y="-7214.3" font-family="Times,serif" font-size="14.00">copy(·)</text>
+</g>
+<!-- 82&#45;&gt;83 -->
+<g id="edge52" class="edge">
+<title>82&#45;&gt;83</title>
+<path fill="none" stroke="black" d="M3876.8,-7271.7C3875.92,-7263.98 3874.86,-7254.71 3873.88,-7246.11"/>
+<polygon fill="black" stroke="black" points="3877.35,-7245.64 3872.73,-7236.1 3870.39,-7246.44 3877.35,-7245.64"/>
+</g>
+<!-- 84 -->
+<g id="node65" class="node">
+<title>84</title>
+<polygon fill="none" stroke="black" points="4085.28,-7164 3642.28,-7164 3642.28,-7128 4085.28,-7128 4085.28,-7164"/>
+<text text-anchor="middle" x="3863.78" y="-7142.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;1 &#160;14 768]| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 83&#45;&gt;84 -->
+<g id="edge53" class="edge">
+<title>83&#45;&gt;84</title>
+<path fill="none" stroke="black" d="M3869.05,-7199.7C3868.28,-7191.98 3867.35,-7182.71 3866.49,-7174.11"/>
+<polygon fill="black" stroke="black" points="3869.97,-7173.71 3865.49,-7164.1 3863,-7174.4 3869.97,-7173.71"/>
+</g>
+<!-- 85 -->
+<g id="node66" class="node">
+<title>85</title>
+<polygon fill="none" stroke="black" points="4087.78,-7092 3639.78,-7092 3639.78,-7056 4087.78,-7056 4087.78,-7092"/>
+<text text-anchor="middle" x="3863.78" y="-7070.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 &#160;14 768]| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 84&#45;&gt;85 -->
+<g id="edge54" class="edge">
+<title>84&#45;&gt;85</title>
+<path fill="none" stroke="black" d="M3863.78,-7127.7C3863.78,-7119.98 3863.78,-7110.71 3863.78,-7102.11"/>
+<polygon fill="black" stroke="black" points="3867.28,-7102.1 3863.78,-7092.1 3860.28,-7102.1 3867.28,-7102.1"/>
+</g>
+<!-- 89 -->
+<g id="node70" class="node">
+<title>89</title>
+<polygon fill="none" stroke="black" points="4212.28,-7020 4043.28,-7020 4043.28,-6984 4212.28,-6984 4212.28,-7020"/>
+<text text-anchor="middle" x="4127.78" y="-6998.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 85&#45;&gt;89 -->
+<g id="edge58" class="edge">
+<title>85&#45;&gt;89</title>
+<path fill="none" stroke="black" d="M3928.02,-7055.97C3965.87,-7045.93 4013.92,-7033.19 4053.4,-7022.72"/>
+<polygon fill="black" stroke="black" points="4054.48,-7026.06 4063.25,-7020.11 4052.68,-7019.29 4054.48,-7026.06"/>
+</g>
+<!-- 348 -->
+<g id="node263" class="node">
+<title>348</title>
+<polygon fill="none" stroke="black" points="3125.28,-4500 2914.28,-4500 2914.28,-4464 3125.28,-4464 3125.28,-4500"/>
+<text text-anchor="middle" x="3019.78" y="-4478.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 85&#45;&gt;348 -->
+<g id="edge325" class="edge">
+<title>85&#45;&gt;348</title>
+<path fill="none" stroke="black" d="M3639.58,-7060.98C3361.25,-7042.98 2922.78,-7003.18 2922.78,-6931 2922.78,-6931 2922.78,-6931 2922.78,-6425 2922.78,-6329.89 2923.78,-6306.11 2923.78,-6211 2923.78,-6211 2923.78,-6211 2923.78,-4625 2923.78,-4576.85 2962.12,-4532.85 2990.34,-4506.98"/>
+<polygon fill="black" stroke="black" points="2992.85,-4509.43 2998,-4500.17 2988.2,-4504.2 2992.85,-4509.43"/>
+</g>
+<!-- 87 -->
+<g id="node68" class="node">
+<title>87</title>
+<polygon fill="none" stroke="black" points="4603.28,-7164 4142.28,-7164 4142.28,-7128 4603.28,-7128 4603.28,-7164"/>
+<text text-anchor="middle" x="4372.78" y="-7142.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 768 768]| newshape=[&#45;1, 768, 768], reverse=0)</text>
+</g>
+<!-- 86&#45;&gt;87 -->
+<g id="edge56" class="edge">
+<title>86&#45;&gt;87</title>
+<path fill="none" stroke="black" d="M4362.99,-7199.7C4364.43,-7191.98 4366.15,-7182.71 4367.74,-7174.11"/>
+<polygon fill="black" stroke="black" points="4371.22,-7174.58 4369.6,-7164.1 4364.34,-7173.3 4371.22,-7174.58"/>
+</g>
+<!-- 88 -->
+<g id="node69" class="node">
+<title>88</title>
+<polygon fill="none" stroke="black" points="4478.28,-7092 4267.28,-7092 4267.28,-7056 4478.28,-7056 4478.28,-7092"/>
+<text text-anchor="middle" x="4372.78" y="-7070.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 87&#45;&gt;88 -->
+<g id="edge57" class="edge">
+<title>87&#45;&gt;88</title>
+<path fill="none" stroke="black" d="M4372.78,-7127.7C4372.78,-7119.98 4372.78,-7110.71 4372.78,-7102.11"/>
+<polygon fill="black" stroke="black" points="4376.28,-7102.1 4372.78,-7092.1 4369.28,-7102.1 4376.28,-7102.1"/>
+</g>
+<!-- 88&#45;&gt;89 -->
+<g id="edge59" class="edge">
+<title>88&#45;&gt;89</title>
+<path fill="none" stroke="black" d="M4313.16,-7055.97C4278.33,-7046.01 4234.2,-7033.41 4197.74,-7022.99"/>
+<polygon fill="black" stroke="black" points="4198.24,-7019.49 4187.67,-7020.11 4196.32,-7026.22 4198.24,-7019.49"/>
+</g>
+<!-- 263 -->
+<g id="node196" class="node">
+<title>263</title>
+<polygon fill="none" stroke="black" points="4566.28,-5868 4355.28,-5868 4355.28,-5832 4566.28,-5832 4566.28,-5868"/>
+<text text-anchor="middle" x="4460.78" y="-5846.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 88&#45;&gt;263 -->
+<g id="edge237" class="edge">
+<title>88&#45;&gt;263</title>
+<path fill="none" stroke="black" d="M4392.31,-7055.86C4418.24,-7031.11 4460.78,-6982.63 4460.78,-6931 4460.78,-6931 4460.78,-6931 4460.78,-5993 4460.78,-5953 4460.78,-5906.65 4460.78,-5878.08"/>
+<polygon fill="black" stroke="black" points="4464.28,-5878.05 4460.78,-5868.05 4457.28,-5878.05 4464.28,-5878.05"/>
+</g>
+<!-- 90 -->
+<g id="node71" class="node">
+<title>90</title>
+<polygon fill="none" stroke="black" points="4349.28,-6948 3906.28,-6948 3906.28,-6912 4349.28,-6912 4349.28,-6948"/>
+<text text-anchor="middle" x="4127.78" y="-6926.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;1 &#160;14 768]| newshape=[1, 14, 768], reverse=0)</text>
+</g>
+<!-- 89&#45;&gt;90 -->
+<g id="edge60" class="edge">
+<title>89&#45;&gt;90</title>
+<path fill="none" stroke="black" d="M4127.78,-6983.7C4127.78,-6975.98 4127.78,-6966.71 4127.78,-6958.11"/>
+<polygon fill="black" stroke="black" points="4131.28,-6958.1 4127.78,-6948.1 4124.28,-6958.1 4131.28,-6958.1"/>
+</g>
+<!-- 90&#45;&gt;91 -->
+<g id="edge61" class="edge">
+<title>90&#45;&gt;91</title>
+<path fill="none" stroke="black" d="M4009.03,-6911.97C3904.96,-6897.04 3759.19,-6876.13 3685.82,-6865.6"/>
+<polygon fill="black" stroke="black" points="3685.93,-6862.08 3675.54,-6864.13 3684.94,-6869.01 3685.93,-6862.08"/>
+</g>
+<!-- 94 -->
+<g id="node74" class="node">
+<title>94</title>
+<polygon fill="none" stroke="black" points="3578.78,-6804 3474.78,-6804 3474.78,-6768 3578.78,-6768 3578.78,-6804"/>
+<text text-anchor="middle" x="3526.78" y="-6782.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 91&#45;&gt;94 -->
+<g id="edge64" class="edge">
+<title>91&#45;&gt;94</title>
+<path fill="none" stroke="black" d="M3612.14,-6839.88C3597.36,-6830.72 3578.99,-6819.34 3563.08,-6809.48"/>
+<polygon fill="black" stroke="black" points="3564.87,-6806.48 3554.53,-6804.19 3561.19,-6812.43 3564.87,-6806.48"/>
+</g>
+<!-- 93&#45;&gt;94 -->
+<g id="edge65" class="edge">
+<title>93&#45;&gt;94</title>
+<path fill="none" stroke="black" d="M3230.88,-6840.94C3302.34,-6827.91 3401.28,-6809.88 3464.67,-6798.32"/>
+<polygon fill="black" stroke="black" points="3465.31,-6801.76 3474.52,-6796.53 3464.06,-6794.88 3465.31,-6801.76"/>
+</g>
+<!-- 262 -->
+<g id="node195" class="node">
+<title>262</title>
+<polygon fill="none" stroke="black" points="3092.78,-1980 2988.78,-1980 2988.78,-1944 3092.78,-1944 3092.78,-1980"/>
+<text text-anchor="middle" x="3040.78" y="-1958.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 93&#45;&gt;262 -->
+<g id="edge236" class="edge">
+<title>93&#45;&gt;262</title>
+<path fill="none" stroke="black" d="M3104.95,-6839.9C3062.28,-6817.72 2998.78,-6774.69 2998.78,-6715 2998.78,-6715 2998.78,-6715 2998.78,-6641 2998.78,-6443.4 2885.78,-6408.6 2885.78,-6211 2885.78,-6211 2885.78,-6211 2885.78,-4553 2885.78,-4512.55 2883.91,-4498.65 2904.78,-4464 2917.64,-4442.65 2936.23,-4450.09 2947.78,-4428 2995.96,-4335.86 2968.78,-4298.98 2968.78,-4195 2968.78,-4195 2968.78,-4195 2968.78,-4121 2968.78,-3640.82 3040.78,-3523.18 3040.78,-3043 3040.78,-3043 3040.78,-3043 [...]
+<polygon fill="black" stroke="black" points="3044.28,-1990.05 3040.78,-1980.05 3037.28,-1990.05 3044.28,-1990.05"/>
+</g>
+<!-- 94&#45;&gt;95 -->
+<g id="edge66" class="edge">
+<title>94&#45;&gt;95</title>
+<path fill="none" stroke="black" d="M3526.78,-6767.7C3526.78,-6759.98 3526.78,-6750.71 3526.78,-6742.11"/>
+<polygon fill="black" stroke="black" points="3530.28,-6742.1 3526.78,-6732.1 3523.28,-6742.1 3530.28,-6742.1"/>
+</g>
+<!-- 97 -->
+<g id="node76" class="node">
+<title>97</title>
+<polygon fill="none" stroke="black" points="3690.28,-6660 3363.28,-6660 3363.28,-6624 3690.28,-6624 3690.28,-6660"/>
+<text text-anchor="middle" x="3526.78" y="-6638.3" font-family="Times,serif" font-size="14.00">mean(·| axis=[&#45;1], keepdims=1, exclude=0)</text>
+</g>
+<!-- 95&#45;&gt;97 -->
+<g id="edge68" class="edge">
+<title>95&#45;&gt;97</title>
+<path fill="none" stroke="black" d="M3526.78,-6695.7C3526.78,-6687.98 3526.78,-6678.71 3526.78,-6670.11"/>
+<polygon fill="black" stroke="black" points="3530.28,-6670.1 3526.78,-6660.1 3523.28,-6670.1 3530.28,-6670.1"/>
+</g>
+<!-- 98 -->
+<g id="node77" class="node">
+<title>98</title>
+<polygon fill="none" stroke="black" points="3388.28,-6588 3283.28,-6588 3283.28,-6552 3388.28,-6552 3388.28,-6588"/>
+<text text-anchor="middle" x="3335.78" y="-6566.3" font-family="Times,serif" font-size="14.00">subtract(·, ·)</text>
+</g>
+<!-- 95&#45;&gt;98 -->
+<g id="edge69" class="edge">
+<title>95&#45;&gt;98</title>
+<path fill="none" stroke="black" d="M3491.01,-6712.3C3451.53,-6709.51 3388.91,-6698.75 3353.78,-6660 3338.68,-6643.35 3334.82,-6617.63 3334.36,-6598.12"/>
+<polygon fill="black" stroke="black" points="3337.86,-6598.03 3334.4,-6588.01 3330.86,-6598 3337.86,-6598.03"/>
+</g>
+<!-- 101 -->
+<g id="node78" class="node">
+<title>101</title>
+<polygon fill="none" stroke="black" points="3833.78,-6588 3471.78,-6588 3471.78,-6552 3833.78,-6552 3833.78,-6588"/>
+<text text-anchor="middle" x="3652.78" y="-6566.3" font-family="Times,serif" font-size="14.00">variance(·, ·| axis=[&#45;1], keepdims=1, exclude=0)</text>
+</g>
+<!-- 95&#45;&gt;101 -->
+<g id="edge71" class="edge">
+<title>95&#45;&gt;101</title>
+<path fill="none" stroke="black" d="M3562.37,-6713.5C3602.89,-6711.92 3667.51,-6702.36 3698.78,-6660 3713.6,-6639.92 3697.36,-6614.39 3680.21,-6595.82"/>
+<polygon fill="black" stroke="black" points="3682.39,-6593.04 3672.9,-6588.33 3677.38,-6597.93 3682.39,-6593.04"/>
+</g>
+<!-- 237 -->
+<g id="node178" class="node">
+<title>237</title>
+<polygon fill="none" stroke="black" points="3599.78,-2484 3495.78,-2484 3495.78,-2448 3599.78,-2448 3599.78,-2484"/>
+<text text-anchor="middle" x="3547.78" y="-2462.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 95&#45;&gt;237 -->
+<g id="edge213" class="edge">
+<title>95&#45;&gt;237</title>
+<path fill="none" stroke="black" d="M3562.54,-6707.27C3655.9,-6691.3 3899.78,-6642.72 3899.78,-6571 3899.78,-6571 3899.78,-6571 3899.78,-2825 3899.78,-2781.74 3691.57,-2547.02 3657.78,-2520 3641.78,-2507.2 3622.18,-2496.52 3603.98,-2488.17"/>
+<polygon fill="black" stroke="black" points="3605.18,-2484.87 3594.62,-2484.02 3602.34,-2491.27 3605.18,-2484.87"/>
+</g>
+<!-- 97&#45;&gt;98 -->
+<g id="edge70" class="edge">
+<title>97&#45;&gt;98</title>
+<path fill="none" stroke="black" d="M3480.06,-6623.88C3453.48,-6614.14 3420.02,-6601.87 3392.01,-6591.61"/>
+<polygon fill="black" stroke="black" points="3392.95,-6588.22 3382.35,-6588.07 3390.54,-6594.8 3392.95,-6588.22"/>
+</g>
+<!-- 97&#45;&gt;101 -->
+<g id="edge72" class="edge">
+<title>97&#45;&gt;101</title>
+<path fill="none" stroke="black" d="M3557.6,-6623.88C3574.23,-6614.64 3594.94,-6603.13 3612.8,-6593.21"/>
+<polygon fill="black" stroke="black" points="3614.8,-6596.11 3621.84,-6588.19 3611.4,-6589.99 3614.8,-6596.11"/>
+</g>
+<!-- 242 -->
+<g id="node181" class="node">
+<title>242</title>
+<polygon fill="none" stroke="black" points="3771.78,-2484 3667.78,-2484 3667.78,-2448 3771.78,-2448 3771.78,-2484"/>
+<text text-anchor="middle" x="3719.78" y="-2462.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 97&#45;&gt;242 -->
+<g id="edge216" class="edge">
+<title>97&#45;&gt;242</title>
+<path fill="none" stroke="black" d="M3690.31,-6630.3C3755.27,-6622.71 3819.97,-6609.82 3842.78,-6588 3872,-6560.04 3861.78,-6539.45 3861.78,-6499 3861.78,-6499 3861.78,-6499 3861.78,-2609 3861.78,-2568.46 3867.62,-2551.24 3841.78,-2520 3833.16,-2509.58 3806.89,-2497.6 3780.93,-2487.67"/>
+<polygon fill="black" stroke="black" points="3782,-2484.34 3771.41,-2484.11 3779.55,-2490.89 3782,-2484.34"/>
+</g>
+<!-- 105 -->
+<g id="node81" class="node">
+<title>105</title>
+<polygon fill="none" stroke="black" points="3040.28,-6372 2951.28,-6372 2951.28,-6336 3040.28,-6336 3040.28,-6372"/>
+<text text-anchor="middle" x="2995.78" y="-6350.3" font-family="Times,serif" font-size="14.00">divide(·, ·)</text>
+</g>
+<!-- 98&#45;&gt;105 -->
+<g id="edge75" class="edge">
+<title>98&#45;&gt;105</title>
+<path fill="none" stroke="black" d="M3308.53,-6551.85C3247.6,-6513.5 3100.35,-6420.82 3031.89,-6377.73"/>
+<polygon fill="black" stroke="black" points="3033.49,-6374.6 3023.16,-6372.23 3029.76,-6380.52 3033.49,-6374.6"/>
+</g>
+<!-- 102 -->
+<g id="node79" class="node">
+<title>102</title>
+<polygon fill="none" stroke="black" points="3798.28,-6516 3737.28,-6516 3737.28,-6480 3798.28,-6480 3798.28,-6516"/>
+<text text-anchor="middle" x="3767.78" y="-6494.3" font-family="Times,serif" font-size="14.00">sqrt(·)</text>
+</g>
+<!-- 101&#45;&gt;102 -->
+<g id="edge73" class="edge">
+<title>101&#45;&gt;102</title>
+<path fill="none" stroke="black" d="M3680.91,-6551.88C3695.95,-6542.72 3714.65,-6531.34 3730.84,-6521.48"/>
+<polygon fill="black" stroke="black" points="3732.82,-6524.38 3739.54,-6516.19 3729.18,-6518.4 3732.82,-6524.38"/>
+</g>
+<!-- 233 -->
+<g id="node175" class="node">
+<title>233</title>
+<polygon fill="none" stroke="black" points="3394.28,-4932 3305.28,-4932 3305.28,-4896 3394.28,-4896 3394.28,-4932"/>
+<text text-anchor="middle" x="3349.78" y="-4910.3" font-family="Times,serif" font-size="14.00">power(·, ·)</text>
+</g>
+<!-- 101&#45;&gt;233 -->
+<g id="edge207" class="edge">
+<title>101&#45;&gt;233</title>
+<path fill="none" stroke="black" d="M3662.1,-6551.63C3675.28,-6525.41 3697.78,-6473.81 3697.78,-6427 3697.78,-6427 3697.78,-6427 3697.78,-5201 3697.78,-5080.64 3660.83,-5033.37 3559.78,-4968 3512.77,-4937.59 3449.4,-4924.62 3404.61,-4919.1"/>
+<polygon fill="black" stroke="black" points="3404.83,-4915.6 3394.49,-4917.94 3404.03,-4922.56 3404.83,-4915.6"/>
+</g>
+<!-- 104 -->
+<g id="node80" class="node">
+<title>104</title>
+<polygon fill="none" stroke="black" points="3833.28,-6444 3726.28,-6444 3726.28,-6408 3833.28,-6408 3833.28,-6444"/>
+<text text-anchor="middle" x="3779.78" y="-6422.3" font-family="Times,serif" font-size="14.00">add(·, 1e&#45;12)</text>
+</g>
+<!-- 102&#45;&gt;104 -->
+<g id="edge74" class="edge">
+<title>102&#45;&gt;104</title>
+<path fill="none" stroke="black" d="M3770.75,-6479.7C3772.07,-6471.98 3773.66,-6462.71 3775.13,-6454.11"/>
+<polygon fill="black" stroke="black" points="3778.61,-6454.55 3776.85,-6444.1 3771.71,-6453.37 3778.61,-6454.55"/>
+</g>
+<!-- 104&#45;&gt;105 -->
+<g id="edge76" class="edge">
+<title>104&#45;&gt;105</title>
+<path fill="none" stroke="black" d="M3726.01,-6425.45C3605.8,-6425.53 3306.35,-6420.54 3050.26,-6372.03"/>
+<polygon fill="black" stroke="black" points="3050.88,-6368.59 3040.4,-6370.14 3049.56,-6375.46 3050.88,-6368.59"/>
+</g>
+<!-- 227 -->
+<g id="node171" class="node">
+<title>227</title>
+<polygon fill="none" stroke="black" points="3630.28,-2916 3541.28,-2916 3541.28,-2880 3630.28,-2880 3630.28,-2916"/>
+<text text-anchor="middle" x="3585.78" y="-2894.3" font-family="Times,serif" font-size="14.00">divide(·, ·)</text>
+</g>
+<!-- 104&#45;&gt;227 -->
+<g id="edge203" class="edge">
+<title>104&#45;&gt;227</title>
+<path fill="none" stroke="black" d="M3770.66,-6407.61C3757.78,-6381.36 3735.78,-6329.71 3735.78,-6283 3735.78,-6283 3735.78,-6283 3735.78,-3041 3735.78,-2984.42 3678.8,-2943.65 3635.08,-2920.68"/>
+<polygon fill="black" stroke="black" points="3636.48,-2917.47 3625.98,-2916.06 3633.3,-2923.71 3636.48,-2917.47"/>
+</g>
+<!-- 249 -->
+<g id="node186" class="node">
+<title>249</title>
+<polygon fill="none" stroke="black" points="3829.28,-2916 3740.28,-2916 3740.28,-2880 3829.28,-2880 3829.28,-2916"/>
+<text text-anchor="middle" x="3784.78" y="-2894.3" font-family="Times,serif" font-size="14.00">divide(·, ·)</text>
+</g>
+<!-- 104&#45;&gt;249 -->
+<g id="edge224" class="edge">
+<title>104&#45;&gt;249</title>
+<path fill="none" stroke="black" d="M3786.26,-6407.79C3795.61,-6381.36 3811.78,-6328.99 3811.78,-6283 3811.78,-6283 3811.78,-6283 3811.78,-3041 3811.78,-3000.5 3801.05,-2954.63 3793.12,-2926.26"/>
+<polygon fill="black" stroke="black" points="3796.38,-2924.93 3790.24,-2916.29 3789.65,-2926.87 3796.38,-2924.93"/>
+</g>
+<!-- 105&#45;&gt;106 -->
+<g id="edge77" class="edge">
+<title>105&#45;&gt;106</title>
+<path fill="none" stroke="black" d="M3040.28,-6338.61C3043.15,-6337.72 3046.01,-6336.84 3048.78,-6336 3089.54,-6323.63 3135.6,-6310.43 3171.56,-6300.3"/>
+<polygon fill="black" stroke="black" points="3172.74,-6303.6 3181.42,-6297.53 3170.85,-6296.86 3172.74,-6303.6"/>
+</g>
+<!-- 226 -->
+<g id="node170" class="node">
+<title>226</title>
+<polygon fill="none" stroke="black" points="3637.78,-2988 3533.78,-2988 3533.78,-2952 3637.78,-2952 3637.78,-2988"/>
+<text text-anchor="middle" x="3585.78" y="-2966.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 105&#45;&gt;226 -->
+<g id="edge201" class="edge">
+<title>105&#45;&gt;226</title>
+<path fill="none" stroke="black" d="M3040.45,-6338.1C3043.26,-6337.34 3046.05,-6336.63 3048.78,-6336 3250.99,-6289.11 3498.78,-6418.57 3498.78,-6211 3498.78,-6211 3498.78,-6211 3498.78,-3113 3498.78,-3066.16 3533.62,-3021.74 3559.21,-2995.43"/>
+<polygon fill="black" stroke="black" points="3561.86,-2997.72 3566.47,-2988.18 3556.92,-2992.77 3561.86,-2997.72"/>
+</g>
+<!-- 356 -->
+<g id="node269" class="node">
+<title>356</title>
+<polygon fill="none" stroke="black" points="2977.78,-3060 2873.78,-3060 2873.78,-3024 2977.78,-3024 2977.78,-3060"/>
+<text text-anchor="middle" x="2925.78" y="-3038.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 105&#45;&gt;356 -->
+<g id="edge333" class="edge">
+<title>105&#45;&gt;356</title>
+<path fill="none" stroke="black" d="M3039.37,-6335.97C3086.14,-6314.5 3153.78,-6272.78 3153.78,-6211 3153.78,-6211 3153.78,-6211 3153.78,-3329 3153.78,-3208.01 3027.48,-3108.69 2962.88,-3065.8"/>
+<polygon fill="black" stroke="black" points="2964.48,-3062.67 2954.2,-3060.13 2960.65,-3068.53 2964.48,-3062.67"/>
+</g>
+<!-- 106&#45;&gt;107 -->
+<g id="edge79" class="edge">
+<title>106&#45;&gt;107</title>
+<path fill="none" stroke="black" d="M3181.49,-6278.91C3055.58,-6273.56 2725.73,-6257.66 2451.78,-6228 2429.23,-6225.56 2404.3,-6221.95 2383.52,-6218.68"/>
+<polygon fill="black" stroke="black" points="2383.8,-6215.18 2373.37,-6217.06 2382.69,-6222.09 2383.8,-6215.18"/>
+</g>
+<!-- 108 -->
+<g id="node84" class="node">
+<title>108</title>
+<polygon fill="none" stroke="black" points="2627.78,-6156 2179.78,-6156 2179.78,-6120 2627.78,-6120 2627.78,-6156"/>
+<text text-anchor="middle" x="2403.78" y="-6134.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#45;1 &#160;14 768]| newshape=[&#45;1, 14, 768], reverse=0)</text>
+</g>
+<!-- 107&#45;&gt;108 -->
+<g id="edge81" class="edge">
+<title>107&#45;&gt;108</title>
+<path fill="none" stroke="black" d="M2354.09,-6191.7C2362.09,-6183.22 2371.85,-6172.86 2380.6,-6163.58"/>
+<polygon fill="black" stroke="black" points="2383.34,-6165.78 2387.65,-6156.1 2378.25,-6160.98 2383.34,-6165.78"/>
+</g>
+<!-- 138 -->
+<g id="node105" class="node">
+<title>138</title>
+<polygon fill="none" stroke="black" points="2059.28,-5148 1988.28,-5148 1988.28,-5112 2059.28,-5112 2059.28,-5148"/>
+<text text-anchor="middle" x="2023.78" y="-5126.3" font-family="Times,serif" font-size="14.00">add(·, ·)</text>
+</g>
+<!-- 107&#45;&gt;138 -->
+<g id="edge109" class="edge">
+<title>107&#45;&gt;138</title>
+<path fill="none" stroke="black" d="M2302.14,-6209.1C2262.71,-6207.06 2200.78,-6196.99 2170.78,-6156 2161.33,-6143.09 2165.39,-6135.06 2170.78,-6120 2259.81,-5871.18 2514.78,-5899.27 2514.78,-5635 2514.78,-5635 2514.78,-5635 2514.78,-5489 2514.78,-5269.67 2190.94,-5170.47 2069.26,-5140.94"/>
+<polygon fill="black" stroke="black" points="2070.03,-5137.52 2059.49,-5138.61 2068.4,-5144.33 2070.03,-5137.52"/>
+</g>
+<!-- 113 -->
+<g id="node88" class="node">
+<title>113</title>
+<polygon fill="none" stroke="black" points="2669.28,-6084 2500.28,-6084 2500.28,-6048 2669.28,-6048 2669.28,-6084"/>
+<text text-anchor="middle" x="2584.78" y="-6062.3" font-family="Times,serif" font-size="14.00">nn.batch_matmul(·, ·)</text>
+</g>
+<!-- 108&#45;&gt;113 -->
+<g id="edge85" class="edge">
+<title>108&#45;&gt;113</title>
+<path fill="none" stroke="black" d="M2448.06,-6119.88C2473.13,-6110.18 2504.68,-6097.98 2531.14,-6087.74"/>
+<polygon fill="black" stroke="black" points="2532.58,-6090.94 2540.65,-6084.07 2530.06,-6084.41 2532.58,-6090.94"/>
+</g>
+<!-- 362 -->
+<g id="node273" class="node">
+<title>362</title>
+<polygon fill="none" stroke="black" points="2773.28,-3924 2562.28,-3924 2562.28,-3888 2773.28,-3888 2773.28,-3924"/>
+<text text-anchor="middle" x="2667.78" y="-3902.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 108&#45;&gt;362 -->
+<g id="edge337" class="edge">
+<title>108&#45;&gt;362</title>
+<path fill="none" stroke="black" d="M2388.3,-6119.7C2363.96,-6090.31 2321.41,-6028.68 2340.78,-5976 2349.84,-5951.36 2480.34,-5818.54 2493.78,-5796 2591.84,-5631.55 2571.78,-5336.31 2571.78,-5275 2571.78,-5275 2571.78,-5275 2571.78,-4049 2571.78,-4000.85 2610.12,-3956.85 2638.34,-3930.98"/>
+<polygon fill="black" stroke="black" points="2640.85,-3933.43 2646,-3924.17 2636.2,-3928.2 2640.85,-3933.43"/>
+</g>
+<!-- 111 -->
+<g id="node86" class="node">
+<title>111</title>
+<polygon fill="none" stroke="black" points="2858.28,-6228 2461.28,-6228 2461.28,-6192 2858.28,-6192 2858.28,-6228"/>
+<text text-anchor="middle" x="2659.78" y="-6206.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;&#45;1 &#160;768 3072]| newshape=..., reverse=0)</text>
+</g>
+<!-- 109&#45;&gt;111 -->
+<g id="edge83" class="edge">
+<title>109&#45;&gt;111</title>
+<path fill="none" stroke="black" d="M2659.78,-6263.7C2659.78,-6255.98 2659.78,-6246.71 2659.78,-6238.11"/>
+<polygon fill="black" stroke="black" points="2663.28,-6238.1 2659.78,-6228.1 2656.28,-6238.1 2663.28,-6238.1"/>
+</g>
+<!-- 112 -->
+<g id="node87" class="node">
+<title>112</title>
+<polygon fill="none" stroke="black" points="2857.28,-6156 2646.28,-6156 2646.28,-6120 2857.28,-6120 2857.28,-6156"/>
+<text text-anchor="middle" x="2751.78" y="-6134.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 111&#45;&gt;112 -->
+<g id="edge84" class="edge">
+<title>111&#45;&gt;112</title>
+<path fill="none" stroke="black" d="M2682.52,-6191.7C2694.22,-6182.8 2708.64,-6171.82 2721.29,-6162.2"/>
+<polygon fill="black" stroke="black" points="2723.46,-6164.94 2729.3,-6156.1 2719.22,-6159.37 2723.46,-6164.94"/>
+</g>
+<!-- 112&#45;&gt;113 -->
+<g id="edge86" class="edge">
+<title>112&#45;&gt;113</title>
+<path fill="none" stroke="black" d="M2710.93,-6119.88C2688,-6110.26 2659.2,-6098.19 2634.91,-6088.01"/>
+<polygon fill="black" stroke="black" points="2636.08,-6084.71 2625.5,-6084.07 2633.37,-6091.16 2636.08,-6084.71"/>
+</g>
+<!-- 222 -->
+<g id="node166" class="node">
+<title>222</title>
+<polygon fill="none" stroke="black" points="2939.28,-4428 2728.28,-4428 2728.28,-4392 2939.28,-4392 2939.28,-4428"/>
+<text text-anchor="middle" x="2833.78" y="-4406.3" font-family="Times,serif" font-size="14.00">transpose(·| axes=[0, 2, 1])</text>
+</g>
+<!-- 112&#45;&gt;222 -->
+<g id="edge193" class="edge">
+<title>112&#45;&gt;222</title>
+<path fill="none" stroke="black" d="M2773.56,-6119.83C2801.89,-6095.39 2847.78,-6047.67 2847.78,-5995 2847.78,-5995 2847.78,-5995 2847.78,-4553 2847.78,-4513.02 2842.22,-4467.02 2838.1,-4438.47"/>
+<polygon fill="black" stroke="black" points="2841.54,-4437.81 2836.61,-4428.43 2834.62,-4438.84 2841.54,-4437.81"/>
+</g>
+<!-- 115 -->
+<g id="node89" class="node">
+<title>115</title>
+<polygon fill="none" stroke="black" points="2819.78,-6012 2349.78,-6012 2349.78,-5976 2819.78,-5976 2819.78,-6012"/>
+<text text-anchor="middle" x="2584.78" y="-5990.3" font-family="Times,serif" font-size="14.00">reshape(·, [ &#160;&#160;1 &#160;&#160;14 3072]| newshape=[1, 14, 3072], reverse=0)</text>
+</g>
+<!-- 113&#45;&gt;115 -->
+<g id="edge87" class="edge">
+<title>113&#45;&gt;115</title>
+<path fill="none" stroke="black" d="M2584.78,-6047.7C2584.78,-6039.98 2584.78,-6030.71 2584.78,-6022.11"/>
+<polygon fill="black" stroke="black" points="2588.28,-6022.1 2584.78,-6012.1 2581.28,-6022.1 2588.28,-6022.1"/>
+</g>
+<!-- 115&#45;&gt;116 -->
+<g id="edge88" class="edge">
+<title>115&#45;&gt;116</title>
+<path fill="none" stroke="black" d="M2407.87,-5975.97C2241.03,-5959.9 2002.24,-5936.91 1903.35,-5927.39"/>
+<polygon fill="black" stroke="black" points="1903.68,-5923.9 1893.39,-5926.43 1903,-5930.87 1903.68,-5923.9"/>
+</g>
+<!-- 120 -->
+<g id="node91" class="node">
+<title>120</title>
+<polygon fill="none" stroke="black" points="1769.28,-5868 1584.28,-5868 1584.28,-5832 1769.28,-5832 1769.28,-5868"/>
+<text text-anchor="middle" x="1676.78" y="-5846.3" font-family="Times,serif" font-size="14.00">multiply(·, 0.70710677)</text>
+</g>
+<!-- 116&#45;&gt;120 -->
+<g id="edge90" class="edge">
+<title>116&#45;&gt;120</title>
+<path fill="none" stroke="black" d="M1822.02,-5907.17C1795.8,-5897.03 1759.77,-5883.1 1730.16,-5871.64"/>
+<polygon fill="black" stroke="black" points="1731.38,-5868.36 1720.79,-5868.02 1728.85,-5874.89 1731.38,-5868.36"/>
+</g>
+<!-- 125 -->
+<g id="node95" class="node">
+<title>125</title>
+<polygon fill="none" stroke="black" points="1848.78,-5580 1744.78,-5580 1744.78,-5544 1848.78,-5544 1848.78,-5580"/>
+<text text-anchor="middle" x="1796.78" y="-5558.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 116&#45;&gt;125 -->
+<g id="edge94" class="edge">
+<title>116&#45;&gt;125</title>
+<path fill="none" stroke="black" d="M1844.84,-5903.66C1826.94,-5877.88 1796.78,-5827.3 1796.78,-5779 1796.78,-5779 1796.78,-5779 1796.78,-5705 1796.78,-5665 1796.78,-5618.65 1796.78,-5590.08"/>
+<polygon fill="black" stroke="black" points="1800.28,-5590.05 1796.78,-5580.05 1793.28,-5590.05 1800.28,-5590.05"/>
+</g>
+<!-- 217 -->
+<g id="node161" class="node">
+<title>217</title>
+<polygon fill="none" stroke="black" points="1812.78,-3564 1708.78,-3564 1708.78,-3528 1812.78,-3528 1812.78,-3564"/>
+<text text-anchor="middle" x="1760.78" y="-3542.3" font-family="Times,serif" font-size="14.00">multiply(·, ·)</text>
+</g>
+<!-- 116&#45;&gt;217 -->
+<g id="edge186" class="edge">
+<title>116&#45;&gt;217</title>
+<path fill="none" stroke="black" d="M1885.8,-5903.94C1920.42,-5880.47 1974.78,-5834.74 1974.78,-5779 1974.78,-5779 1974.78,-5779 1974.78,-5561 1974.78,-5520.46 1982.85,-5501.25 1954.78,-5472 1826.7,-5338.5 1680.14,-5498.18 1552.78,-5364 1524.93,-5334.66 1533.78,-5315.45 1533.78,-5275 1533.78,-5275 1533.78,-5275 1533.78,-4625 1533.78,-4461.96 1453.78,-4430.04 1453.78,-4267 1453.78,-4267 1453.78,-4267 1453.78,-4049 1453.78,-4008.55 1452.68,-3995.1 1472.78,-3960 1508.1,-3898.3 1555.46,-3913 [...]
+<polygon fill="black" stroke="black" points="1712.57,-3571.81 1719.92,-3564.18 1709.41,-3565.56 1712.57,-3571.81"/>
+</g>
+<!-- 121 -->
+<g id="node92" class="node">
+<title>121</title>
+<polygon fill="none" stroke="black" points="1763.78,-5796 1709.78,-5796 1709.78,-5760 1763.78,-5760 1763.78,-5796"/>
+<text text-anchor="middle" x="1736.78" y="-5774.3" font-family="Times,serif" font-size="14.00">erf(·)</text>
+</g>
+<!-- 120&#45;&gt;121 -->
+<g id="edge91" class="edge">
+<title>120&#45;&gt;121</title>
+<path fill="none" stroke="black" d="M1691.61,-5831.7C1698.8,-5823.3 1707.58,-5813.07 1715.47,-5803.86"/>
+<polygon fill="black" stroke="black" points="1718.27,-5805.97 1722.12,-5796.1 1712.95,-5801.42 1718.27,-5805.97"/>
+</g>
+<!-- 213 -->
+<g id="node157" class="node">
+<title>213</title>
+<polygon fill="none" stroke="black" points="1648.28,-5796 1555.28,-5796 1555.28,-5760 1648.28,-5760 1648.28,-5796"/>
+<text text-anchor="middle" x="1601.78" y="-5774.3" font-family="Times,serif" font-size="14.00">negative(·)</text>
+</g>
+<!-- 120&#45;&gt;213 -->
+<g id="edge180" class="edge">
+<title>120&#45;&gt;213</title>
+<path fill="none" stroke="black" d="M1658.24,-5831.7C1648.98,-5823.05 1637.62,-5812.45 1627.52,-5803.03"/>
+<polygon fill="black" stroke="black" points="1629.8,-5800.37 1620.11,-5796.1 1625.03,-5805.49 1629.8,-5800.37"/>
+</g>
+<!-- 214 -->
+<g id="node158" class="node">
+<title>214</title>
+<polygon fill="none" stroke="black" points="1653.78,-5724 1549.78,-5724 1549.78,-5688 1653.78,-5688 1653.78,-5724"/>
... 8160 lines suppressed ...