Posted to commits@mxnet.apache.org by cm...@apache.org on 2019/01/13 17:18:50 UTC

[incubator-mxnet] branch master updated: Modifying clojure CNN text classification example (#13865)

This is an automated email from the ASF dual-hosted git repository.

cmeier pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0e57930  Modifying clojure CNN text classification example (#13865)
0e57930 is described below

commit 0e57930011085cedf59ffe040729ea037ceeece3
Author: Kedar Bellare <ke...@gmail.com>
AuthorDate: Sun Jan 13 09:18:32 2019 -0800

    Modifying clojure CNN text classification example (#13865)
    
    * Modifying clojure CNN text classification example
    
    * Small fixes
    
    * Another minor fix
---
 .../examples/cnn-text-classification/README.md     |  38 +++-
 .../src/cnn_text_classification/classifier.clj     |  45 +++--
 .../src/cnn_text_classification/data_helper.clj    | 195 ++++++++++++++-------
 .../cnn_text_classification/classifier_test.clj    |  52 +++---
 4 files changed, 218 insertions(+), 112 deletions(-)

diff --git a/contrib/clojure-package/examples/cnn-text-classification/README.md b/contrib/clojure-package/examples/cnn-text-classification/README.md
index 86a8abb..19bb913 100644
--- a/contrib/clojure-package/examples/cnn-text-classification/README.md
+++ b/contrib/clojure-package/examples/cnn-text-classification/README.md
@@ -3,19 +3,19 @@
 An example of text classification using CNN
 
 To use you must download the MR polarity dataset and put it in the path specified in the mr-dataset-path
-The dataset can be obtained here: [https://github.com/yoonkim/CNN_sentence](https://github.com/yoonkim/CNN_sentence). The two files `rt-polarity.neg`
+The dataset can be obtained here: [CNN_sentence](https://github.com/yoonkim/CNN_sentence). The two files `rt-polarity.neg`
 and `rt-polarity.pos` must be put in a directory. For example, `data/mr-data/rt-polarity.neg`.
 
 You also must download the glove word embeddings. The suggested one to use is the smaller 50 dimension one
-`glove.6B.50d.txt` which is contained in the download file here [https://nlp.stanford.edu/projects/glove/](https://nlp.stanford.edu/projects/glove/)
+`glove.6B.50d.txt` which is contained in the download file here: [GloVe](https://nlp.stanford.edu/projects/glove/)
 
 ## Usage
 
 You can run through the repl with
-`(train-convnet {:embedding-size 50 :batch-size 100 :test-size 100 :num-epoch 10 :max-examples 1000})`
+`(train-convnet {:embedding-size 50 :batch-size 100 :test-size 100 :num-epoch 10 :max-examples 1000 :pretrained-embedding :glove})`
 
 or
-`JVM_OPTS="Xmx1g" lein run` (cpu)
+`JVM_OPTS="-Xmx1g" lein run` (cpu)
 
 You can control the devices you run on by doing:
 
@@ -24,10 +24,36 @@ You can control the devices you run on by doing:
 `lein run :gpu 2` - This will run on 2 gpu devices
 
 
-The max-examples only loads 1000 each of the dataset to keep the time and memory down. To run all the examples, 
-change the main to be (train-convnet {:embedding-size 50 :batch-size 100 :test-size 1000 :num-epoch 10)
+The max-examples only loads 1000 each of the dataset to keep the time and memory down. To run all the examples,
+change the main to be (train-convnet {:embedding-size 50 :batch-size 100 :test-size 1000 :num-epoch 10 :pretrained-embedding :glove})
 
 and then run
 
 - `lein uberjar`
 - `java -Xms1024m -Xmx2048m -jar target/cnn-text-classification-0.1.0-SNAPSHOT-standalone.jar`
+
+## Usage with word2vec
+
+You can also use word2vec embeddings in order to train the text classification model.
+Before training, you will need to download [GoogleNews-vectors-negative300.bin](https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit?usp=sharing) first.
+Once you've downloaded the embeddings (which are in a gzipped format),
+you'll need to unzip them and place them in the `contrib/clojure-package/data` directory.
+
+Then you can run training on a subset of examples through the repl using:
+```
+(train-convnet {:embedding-size 300 :batch-size 100 :test-size 100 :num-epoch 10 :max-examples 1000 :pretrained-embedding :word2vec})
+```
+Note that loading word2vec embeddings consumes memory and takes some time.
+
+You can also train them using `JVM_OPTS="-Xmx8g" lein run` once you've modified
+the parameters to `train-convnet` (see above) in `src/cnn_text_classification/classifier.clj`.
+In order to run training with word2vec on the complete data set, you will need to run:
+```
+(train-convnet {:embedding-size 300 :batch-size 100 :test-size 1000 :num-epoch 10 :pretrained-embedding :word2vec})
+```
+You should be able to achieve an accuracy of `~0.78` using the parameters above.
+
+## Usage with learned embeddings
+
+Lastly, similar to the python CNN text classification example, you can learn the embeddings based on training data.
+This can be achieved by setting `:pretrained-embedding nil` (or omitting that parameter altogether).
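With `:pretrained-embedding nil`, the sentences are not turned into vectors up front. `load-ms-with-embeddings` pads them with the `</s>` token, builds a frequency-ordered vocabulary with `build-vocab`, and replaces each word with its index. The `vocab_embed` layer then learns the vectors during training.

Below is a minimal, MXNet-free sketch of that preprocessing. It assumes tokenized sentences are given as vectors of strings. The namespace and `sentences->indices` are hypothetical names. Ties in word frequency are broken alphabetically here so the result is deterministic; the commit's `build-vocab` leaves tie order unspecified.

```clojure
(ns sketch.learned-embedding-input) ;; hypothetical namespace

(def eos "</s>") ;; padding token, same value as EOS in data_helper.clj

(defn pad-sentences
  "Right-pads every tokenized sentence with `eos` to the length of the longest one."
  [sentences]
  (let [max-len (apply max (map count sentences))]
    (mapv #(into (vec %) (repeat (- max-len (count %)) eos)) sentences)))

(defn build-vocab
  "Maps each word to a unique index; more frequent words get smaller indices.
  Ties are broken alphabetically (an assumption of this sketch)."
  [sentences]
  (->> (frequencies (apply concat sentences))
       (sort-by (juxt (comp - val) key))
       (map-indexed (fn [i [word _]] [word i]))
       (into {})))

(defn sentences->indices
  "Hypothetical helper: replaces every word with its vocabulary index."
  [sentences vocab]
  (mapv #(mapv vocab %) sentences))

;; Usage
(def example (pad-sentences [["a" "good" "movie"] ["a" "bad" "one" "really"]]))
(def example-vocab (build-vocab example))
;; (sentences->indices example example-vocab) => [[0 3 4 1] [0 2 5 6]]
```

The index matrix has shape `[sentence-count sentence-size]`. That is why, in this mode, the data arrays carry no embedding dimension until the embedding layer adds one.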
diff --git a/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/classifier.clj b/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/classifier.clj
index 94fd4f5..3c0288c 100644
--- a/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/classifier.clj
+++ b/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/classifier.clj
@@ -30,34 +30,48 @@
 
 (def data-dir "data/")
 (def mr-dataset-path "data/mr-data") ;; the MR polarity dataset path
-(def glove-file-path "data/glove/glove.6B.50d.txt")
 (def num-filter 100)
 (def num-label 2)
 (def dropout 0.5)
 
-
-
 (when-not (.exists (io/file (str data-dir)))
   (do (println "Retrieving data for cnn text classification...") (sh "./get_data.sh")))
 
-(defn shuffle-data [test-num {:keys [data label sentence-count sentence-size embedding-size]}]
+(defn shuffle-data [test-num {:keys [data label sentence-count sentence-size vocab-size embedding-size pretrained-embedding]}]
   (println "Shuffling the data and splitting into training and test sets")
   (println {:sentence-count sentence-count
             :sentence-size sentence-size
-            :embedding-size embedding-size})
+            :vocab-size vocab-size
+            :embedding-size embedding-size
+            :pretrained-embedding pretrained-embedding})
   (let [shuffled (shuffle (map #(vector %1 %2) data label))
         train-num (- (count shuffled) test-num)
         training (into [] (take train-num shuffled))
-        test (into [] (drop train-num shuffled))]
+        test (into [] (drop train-num shuffled))
+        ;; has to be channel x y
+        train-data-shape (if pretrained-embedding
+                           [train-num 1 sentence-size embedding-size]
+                           [train-num 1 sentence-size])
+        ;; has to be channel x y
+        test-data-shape (if pretrained-embedding
+                           [test-num 1 sentence-size embedding-size]
+                           [test-num 1 sentence-size])]
     {:training {:data  (ndarray/array (into [] (flatten (mapv first training)))
-                                      [train-num 1 sentence-size embedding-size]) ;; has to be channel x y
+                                      train-data-shape)
                 :label (ndarray/array (into [] (flatten (mapv last  training)))
                                       [train-num])}
      :test {:data  (ndarray/array (into [] (flatten (mapv first test)))
-                                  [test-num 1 sentence-size embedding-size]) ;; has to be channel x y
+                                  test-data-shape)
             :label (ndarray/array (into [] (flatten (mapv last  test)))
                                   [test-num])}}))
 
+(defn get-data-symbol [num-embed sentence-size batch-size vocab-size pretrained-embedding]
+  (if pretrained-embedding
+    (sym/variable "data")
+    (as-> (sym/variable "data") data
+      (sym/embedding "vocab_embed" {:data data :input-dim vocab-size :output-dim num-embed})
+      (sym/reshape {:data data :target-shape [batch-size 1 sentence-size num-embed]}))))
+
 (defn make-filter-layers [{:keys [input-x num-embed sentence-size] :as config}
                           filter-size]
   (as-> (sym/convolution {:data input-x
@@ -71,9 +85,9 @@
 
 ;;; convnet with multiple filter sizes
 ;; from Convolutional Neural Networks for Sentence Classification by Yoon Kim
-(defn get-multi-filter-convnet [num-embed sentence-size batch-size]
+(defn get-multi-filter-convnet [num-embed sentence-size batch-size vocab-size pretrained-embedding]
   (let [filter-list [3 4 5]
-        input-x (sym/variable "data")
+        input-x (get-data-symbol num-embed sentence-size batch-size vocab-size pretrained-embedding)
         polled-outputs (mapv #(make-filter-layers {:input-x input-x :num-embed num-embed :sentence-size sentence-size} %) filter-list)
         total-filters (* num-filter (count filter-list))
         concat (sym/concat "concat" nil polled-outputs {:dim 1})
@@ -82,10 +96,11 @@
         fc (sym/fully-connected  "fc1" {:data hdrop :num-hidden num-label})]
     (sym/softmax-output "softmax" {:data fc})))
 
-(defn train-convnet [{:keys [devs embedding-size batch-size test-size num-epoch max-examples]}]
-  (let [glove (data-helper/load-glove glove-file-path) ;; you can also use word2vec
-        ms-dataset (data-helper/load-ms-with-embeddings mr-dataset-path embedding-size glove max-examples)
+(defn train-convnet [{:keys [devs embedding-size batch-size test-size
+                             num-epoch max-examples pretrained-embedding]}]
+  (let [ms-dataset (data-helper/load-ms-with-embeddings mr-dataset-path max-examples embedding-size {:pretrained-embedding pretrained-embedding})
         sentence-size (:sentence-size ms-dataset)
+        vocab-size (:vocab-size ms-dataset)
         shuffled (shuffle-data test-size ms-dataset)
         train-data (mx-io/ndarray-iter [(get-in shuffled [:training :data])]
                                        {:label [(get-in shuffled [:training :label])]
@@ -97,7 +112,7 @@
                                        :label-name "softmax_label"
                                        :data-batch-size batch-size
                                        :last-batch-handle "pad"})]
-    (let [mod (m/module (get-multi-filter-convnet embedding-size sentence-size batch-size) {:contexts devs})]
+    (let [mod (m/module (get-multi-filter-convnet embedding-size sentence-size batch-size vocab-size pretrained-embedding) {:contexts devs})]
       (println "Getting ready to train for " num-epoch " epochs")
       (println "===========")
       (m/fit mod {:train-data train-data :eval-data test-data :num-epoch num-epoch
@@ -111,7 +126,7 @@
   ;;; omit max-examples if you want to run all the examples in the movie review dataset
     ;; to limit mem consumption set to something like 1000 and adjust test size to 100
     (println "Running with context devices of" devs)
-    (train-convnet {:devs devs :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000})
+    (train-convnet {:devs devs :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000 :pretrained-embedding :glove})
     ;; runs all the examples
     #_(train-convnet {:embedding-size 50 :batch-size 100 :test-size 1000 :num-epoch 10})))
 
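In `shuffle-data`, each split is flattened and passed to `ndarray/array` together with a shape. The shape depends on the mode:

- With pretrained vectors it is `[n 1 sentence-size embedding-size]`, one single-channel "image" per sentence.
- Otherwise it is `[n 1 sentence-size]` of word indices. In that case `get-data-symbol` adds an embedding layer and reshapes its output to `[batch-size 1 sentence-size num-embed]`.

The sketch below restates that shape rule on plain nested Clojure vectors. `check-flattened-size` is a hypothetical sanity check, not part of the commit. It confirms that the flattened element count equals the product of the shape before an array would be built.

```clojure
(ns sketch.input-shapes) ;; hypothetical namespace

(defn data-shape
  "Shape for `n` examples: [n 1 sentence-size embedding-size] when each word is
  already a pretrained vector, otherwise [n 1 sentence-size] word indices."
  [n sentence-size embedding-size pretrained-embedding]
  (if pretrained-embedding
    [n 1 sentence-size embedding-size]
    [n 1 sentence-size]))

(defn check-flattened-size
  "Hypothetical check: returns `shape` if flattening `examples` yields exactly
  (reduce * shape) numbers, otherwise throws ex-info."
  [examples shape]
  (let [expected (reduce * shape)
        actual (count (flatten examples))]
    (if (= expected actual)
      shape
      (throw (ex-info "Flattened data does not match the requested shape"
                      {:shape shape :expected expected :actual actual})))))

;; Usage: one 3-word sentence with 2-dimensional pretrained vectors,
;; then two 3-word sentences encoded as indices for the learned case.
(check-flattened-size [[[0.1 0.2] [0.3 0.4] [0.5 0.6]]] (data-shape 1 3 2 :glove))
(check-flattened-size [[0 1 2] [3 4 0]] (data-shape 2 3 2 nil))
```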
diff --git a/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/data_helper.clj b/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/data_helper.clj
index 7966521..82ba130 100644
--- a/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/data_helper.clj
+++ b/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/data_helper.clj
@@ -21,53 +21,84 @@
             [org.apache.clojure-mxnet.context :as context]
             [org.apache.clojure-mxnet.ndarray :as ndarray]
             [org.apache.clojure-mxnet.random :as random])
-  (:import (java.io DataInputStream))
+  (:import (java.io DataInputStream)
+           (java.nio ByteBuffer ByteOrder))
   (:gen-class))
 
 (def w2v-file-path "../../data/GoogleNews-vectors-negative300.bin") ;; the word2vec file path
-(def max-vectors 100) ;; If you are using word2vec embeddings and you want to only load part of them
-
-(defn r-string [dis]
-  (let [max-size 50
-        bs (byte-array max-size)
-        sb (new StringBuilder)]
-    (loop [b (.readByte dis)
-           i 0]
-      (if (and (not= 32 b) (not= 10 b))
-        (do (aset bs i b)
-            (if (= 49 i)
-              (do (.append sb (new String bs))
-                  (recur (.readByte dis) 0))
-              (recur (.readByte dis) (inc i))))
-        (.append sb (new String bs 0 i))))
-    (.toString sb)))
-
-(defn get-float [b]
-  (-> 0
-      (bit-or (bit-shift-left (bit-and (aget b 0) 0xff) 0))
-      (bit-or (bit-shift-left (bit-and (aget b 1) 0xff) 8))
-      (bit-or (bit-shift-left (bit-and (aget b 2) 0xff) 16))
-      (bit-or (bit-shift-left (bit-and (aget b 3) 0xff) 24))))
+(def EOS "</s>")  ;; end of sentence word
+
+(defn glove-file-path
+  "Returns the file path to GloVe embedding of the input size"
+  [embedding-size]
+  (format "data/glove/glove.6B.%dd.txt" embedding-size))
+
+(defn r-string
+  "Reads a string from the given DataInputStream `dis` until a space or newline is reached."
+  [dis]
+  (loop [b (.readByte dis)
+         bs []]
+    (if (and (not= 32 b) (not= 10 b))
+      (recur (.readByte dis) (conj bs b))
+      (new String (byte-array bs)))))
+
+(defn get-float [bs]
+  (-> (ByteBuffer/wrap bs)
+      (.order ByteOrder/LITTLE_ENDIAN)
+      (.getFloat)))
 
 (defn read-float [is]
   (let [bs (byte-array 4)]
     (do (.read is bs)
         (get-float bs))))
 
-(defn load-google-model [path]
-  (println "Loading the word2vec model from binary ...")
-  (with-open [bis (io/input-stream path)
-              dis (new DataInputStream bis)]
-    (let [word-size (Integer/parseInt (r-string dis))
-          dim  (Integer/parseInt (r-string dis))
-          _  (println "Processing with " {:dim dim :word-size word-size} " loading max vectors " max-vectors)
-          word2vec (reduce (fn [r _]
-                             (assoc r (r-string dis)
-                                    (mapv (fn [_] (read-float dis)) (range dim))))
-                           {}
-                           (range max-vectors))]
-      (println "Finished")
-      {:num-embed dim :word2vec word2vec})))
+(defn- load-w2v-vectors
+  "Lazily loads the word2vec vectors given a data input stream `dis`,
+  number of words `nwords` and dimensionality `embedding-size`."
+  [dis embedding-size num-vectors]
+  (if (= 0 num-vectors)
+    (list)
+    (let [word (r-string dis)
+          vect (mapv (fn [_] (read-float dis)) (range embedding-size))]
+      (cons [word vect] (lazy-seq (load-w2v-vectors dis embedding-size (dec num-vectors)))))))
+
+(defn load-word2vec-model
+  "Loads the word2vec model stored in a binary format from the given `path`.
+  By default only the first 100 embeddings are loaded."
+  ([path embedding-size opts]
+   (println "Loading the word2vec model from binary ...")
+   (with-open [bis (io/input-stream path)
+               dis (new DataInputStream bis)]
+     (let [word-size (Integer/parseInt (r-string dis))
+           dim  (Integer/parseInt (r-string dis))
+           {:keys [max-vectors vocab] :or {max-vectors word-size}} opts
+           _  (println "Processing with " {:dim dim :word-size word-size} " loading max vectors " max-vectors)
+           _ (if (not= embedding-size dim)
+               (throw (ex-info "Mismatch in embedding size"
+                      {:input-embedding-size embedding-size
+                       :word2vec-embedding-size dim})))
+           vectors (load-w2v-vectors dis dim max-vectors)
+           word2vec (if vocab
+                      (->> vectors
+                           (filter (fn [[w _]] (contains? vocab w)))
+                           (into {}))
+                      (->> vectors
+                           (take max-vectors)
+                           (into {})))]
+       (println "Finished")
+       {:num-embed dim :word2vec word2vec})))
+  ([path embedding-size]
+   (load-word2vec-model path embedding-size {:max-vectors 100})))
+
+(defn read-text-embedding-pairs [rdr]
+  (for [^String line (line-seq rdr)
+        :let [fields (.split line " ")]]
+    [(aget fields 0)
+     (mapv #(Float/parseFloat ^String %) (rest fields))]))
+
+(defn load-glove [glove-file-path]
+  (println "Loading the glove pre-trained word embeddings from " glove-file-path)
+  (into {} (read-text-embedding-pairs (io/reader glove-file-path))))
 
 (defn clean-str [s]
   (-> s
@@ -84,9 +115,12 @@
       (string/replace #"\)" " ) ")
       (string/replace #"\?" " ? ")
       (string/replace #" {2,}" " ")
-      (string/trim)));; Loads MR polarity data from files, splits the data into words and generates labels.
- ;; Returns split sentences and labels.
-(defn load-mr-data-and-labels [path max-examples]
+      (string/trim)))
+
+(defn load-mr-data-and-labels
+  "Loads MR polarity data from files, splits the data into words and generates labels. 
+  Returns split sentences and labels."
+  [path max-examples]
   (println "Loading all the movie reviews from " path)
   (let [positive-examples (mapv #(string/trim %) (-> (slurp (str path "/rt-polarity.pos"))
                                                      (string/split #"\n")))
@@ -104,41 +138,68 @@
         negative-labels (mapv (constantly 0) negative-examples)]
     {:sentences x-text :labels (into positive-labels negative-labels)}))
 
-;; Pads all sentences to the same length. The length is defined by the longest sentence.
-;; Returns padded sentences.
-(defn pad-sentences [sentences]
-  (let [padding-word "<s>"
+(defn pad-sentences
+  "Pads all sentences to the same length where the length is defined by the longest sentence. Returns padded sentences."
+  [sentences]
+  (let [padding-word EOS
         sequence-len (apply max (mapv count sentences))]
     (mapv (fn [s] (let [diff (- sequence-len (count s))]
                     (if (pos? diff)
                       (into s (repeat diff padding-word))
                       s)))
-          sentences)));; Map sentences and labels to vectors based on a pretrained embeddings
-(defn build-input-data-with-embeddings [sentences embedding-size embeddings]
-  (mapv (fn [sent]
-          (mapv (fn [word] (or (get embeddings word)
-                               (ndarray/->vec (random/uniform -0.25 0.25 [embedding-size]))))
-                sent))
-        sentences))
-
-(defn load-ms-with-embeddings [path embedding-size embeddings max-examples]
-  (println "Translating the movie review words into the embeddings")
+          sentences)))
+
+(defn build-vocab-embeddings
+  "Returns the subset of `embeddings` for words from the `vocab`.
+  Embeddings for words not in the vocabulary are initialized randomly
+  from a uniform distribution."
+  [vocab embedding-size embeddings]
+  (into {}
+        (mapv (fn [[word _]]
+                [word (or (get embeddings word)
+                          (ndarray/->vec (random/uniform -0.25 0.25 [embedding-size])))])
+              vocab)))
+
+(defn build-input-data-with-embeddings
+  "Map sentences and labels to vectors based on a pretrained embeddings."
+  [sentences embeddings]
+  (mapv (fn [sent] (mapv #(embeddings %) sent)) sentences))
+
+(defn build-vocab
+  "Creates a vocabulary for the data set based on frequency of words.
+  Returns a map from words to unique indices."
+  [sentences]
+  (let [words (flatten sentences)
+        wc (reduce
+            (fn [m w] (update-in m [w] (fnil inc 0)))
+            {}
+            words)
+        sorted-wc (sort-by second > wc)
+        sorted-w (map first sorted-wc)]
+    (into {} (map vector sorted-w (range (count sorted-w))))))
+
+(defn load-ms-with-embeddings
+  "Loads the movie review sentences data set for the given
+  `:pretrained-embedding` (e.g. `nil`, `:glove` or `:word2vec`)"
+  [path max-examples embedding-size {:keys [pretrained-embedding]
+                                     :or {pretrained-embedding nil}
+                                     :as opts}]
   (let [{:keys [sentences labels]} (load-mr-data-and-labels path max-examples)
         sentences-padded  (pad-sentences sentences)
-        data (build-input-data-with-embeddings sentences-padded embedding-size embeddings)]
+        vocab (build-vocab sentences-padded)
+        vocab-embeddings (case pretrained-embedding
+                           :glove (->> (load-glove (glove-file-path embedding-size))
+                                       (build-vocab-embeddings vocab embedding-size))
+                           :word2vec (->> (load-word2vec-model w2v-file-path embedding-size {:vocab vocab})
+                                          (:word2vec)
+                                          (build-vocab-embeddings vocab embedding-size))
+                           vocab)
+        data (build-input-data-with-embeddings sentences-padded vocab-embeddings)]
     {:data data
      :label labels
      :sentence-count (count data)
      :sentence-size (count (first data))
-     :embedding-size embedding-size}))
-
-(defn read-text-embedding-pairs [rdr]
-  (for [^String line (line-seq rdr)
-        :let [fields (.split line " ")]]
-    [(aget fields 0)
-     (mapv #(Double/parseDouble ^String %) (rest fields))]))
-
-(defn load-glove [glove-file-path]
-  (println "Loading the glove pre-trained word embeddings from " glove-file-path)
-  (into {} (read-text-embedding-pairs (io/reader glove-file-path))))
+     :embedding-size embedding-size
+     :vocab-size (count vocab)
+     :pretrained-embedding pretrained-embedding}))
 
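The binary word2vec reader assumes the following file layout:

- an ASCII header holding the word count and the dimension;
- then, for each word, its bytes terminated by a space, followed by `dim` 4-byte floats.

The floats are little-endian. That is why `get-float` wraps the bytes in a `ByteBuffer` ordered with `ByteOrder/LITTLE_ENDIAN`, since `DataInputStream.readFloat` reads big-endian.

The sketch below parses a small in-memory model in that layout. Three choices differ from the commit and are assumptions of this sketch:

1. It calls `.readFully` instead of `.read`, because `InputStream.read` may return fewer than the 4 requested bytes.
2. It skips a newline seen before a word, because some word2vec binary files may end each vector with a newline.
3. It decodes words explicitly as UTF-8 rather than with the platform default charset.

`write-toy-model` is a hypothetical helper used only to build input.

```clojure
(ns sketch.word2vec-binary ;; hypothetical namespace
  (:import (java.io ByteArrayInputStream ByteArrayOutputStream DataInputStream
                    EOFException InputStream)
           (java.nio ByteBuffer ByteOrder)
           (java.nio.charset StandardCharsets)))

(defn read-token
  "Reads bytes up to the next space or newline and decodes them as UTF-8.
  A newline seen before any byte of the token is skipped."
  [^DataInputStream dis]
  (let [buf (ByteArrayOutputStream.)]
    (loop [b (long (.read dis))]
      (cond
        (neg? b) (throw (EOFException. "Unexpected end of word2vec data"))
        (and (== b 10) (zero? (.size buf))) (recur (long (.read dis)))
        (or (== b 32) (== b 10)) (.toString buf "UTF-8")
        :else (do (.write buf (int b))
                  (recur (long (.read dis))))))))

(defn read-le-float
  "Reads exactly 4 bytes and decodes them as a little-endian float."
  [^DataInputStream dis]
  (let [bs (byte-array 4)]
    (.readFully dis bs) ; blocks until 4 bytes are read or throws EOFException
    (-> (ByteBuffer/wrap bs)
        (.order ByteOrder/LITTLE_ENDIAN)
        (.getFloat))))

(defn read-word2vec
  "Parses a binary word2vec model from `in` into {word [float ...]}."
  [^InputStream in]
  (let [dis (DataInputStream. in)
        n (Long/parseLong (read-token dis))
        dim (Long/parseLong (read-token dis))]
    (into {}
          (repeatedly n (fn []
                          (let [word (read-token dis)]
                            [word (vec (repeatedly dim #(read-le-float dis)))]))))))

(defn write-toy-model
  "Hypothetical helper: serializes [[word [x ...]] ...] in the same layout,
  ending each vector with a newline."
  [entries dim]
  (let [out (ByteArrayOutputStream.)
        put-str (fn [^String s] (.write out (.getBytes s StandardCharsets/UTF_8)))]
    (put-str (str (count entries) " " dim "\n"))
    (doseq [[word v] entries]
      (put-str (str word " "))
      (doseq [x v]
        (.write out (-> (ByteBuffer/allocate 4)
                        (.order ByteOrder/LITTLE_ENDIAN)
                        (.putFloat (float x))
                        (.array))))
      (put-str "\n"))
    (.toByteArray out)))

;; Usage: round-trip a two-word, 2-dimensional model through memory.
(def toy-model
  (read-word2vec (ByteArrayInputStream.
                  (write-toy-model [["good" [0.5 -1.25]] ["bad" [2.0 0.75]]] 2))))
```

The commit streams vectors lazily and keeps only words found in the dataset vocabulary. This matters for the full GoogleNews file, where materializing every vector as in this toy reader would take far more memory.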
diff --git a/contrib/clojure-package/examples/cnn-text-classification/test/cnn_text_classification/classifier_test.clj b/contrib/clojure-package/examples/cnn-text-classification/test/cnn_text_classification/classifier_test.clj
index 918a46f..744307e 100644
--- a/contrib/clojure-package/examples/cnn-text-classification/test/cnn_text_classification/classifier_test.clj
+++ b/contrib/clojure-package/examples/cnn-text-classification/test/cnn_text_classification/classifier_test.clj
@@ -16,29 +16,33 @@
 ;;
 
 (ns cnn-text-classification.classifier-test
-	(:require 
-		[clojure.test :refer :all]
-		[org.apache.clojure-mxnet.module :as module]
-		[org.apache.clojure-mxnet.ndarray :as ndarray]
-		[org.apache.clojure-mxnet.util :as util]
-		[org.apache.clojure-mxnet.context :as context]
-		[cnn-text-classification.classifier :as classifier]))
+  (:require [clojure.test :refer :all]
+            [org.apache.clojure-mxnet.module :as module]
+            [org.apache.clojure-mxnet.ndarray :as ndarray]
+            [org.apache.clojure-mxnet.util :as util]
+            [org.apache.clojure-mxnet.context :as context]
+            [cnn-text-classification.classifier :as classifier]))
 
-;
-; The one and unique classifier test
-;
-(deftest classifier-test
-	(let [train
-    (classifier/train-convnet 
-    	{:devs [(context/default-context)]
-         :embedding-size 50 
-         :batch-size 10 
-         :test-size 100 
-         :num-epoch 1 
-         :max-examples 1000})]
+(deftest classifier-with-embeddings-test
+  (let [train (classifier/train-convnet
+               {:devs [(context/default-context)]
+                :embedding-size 50
+                :batch-size 10
+                :test-size 100
+                :num-epoch 1
+                :max-examples 1000
+                :pretrained-embedding :glove})]
     (is (= ["data"] (util/scala-vector->vec (module/data-names train))))
-    (is (= 20 (count (ndarray/->vec (-> train module/outputs first first)))))))
-    ;(prn (util/scala-vector->vec (data-shapes train)))	
-    ;(prn (util/scala-vector->vec (label-shapes train)))
-    ;(prn (output-names train))
-    ;(prn (output-shapes train))
\ No newline at end of file
+    (is (= 20 (count (ndarray/->vec (-> train module/outputs ffirst)))))))
+
+(deftest classifier-without-embeddings-test
+  (let [train (classifier/train-convnet
+               {:devs [(context/default-context)]
+                :embedding-size 50
+                :batch-size 10
+                :test-size 100
+                :num-epoch 1
+                :max-examples 1000
+                :pretrained-embedding nil})]
+    (is (= ["data"] (util/scala-vector->vec (module/data-names train))))
+    (is (= 20 (count (ndarray/->vec (-> train module/outputs ffirst)))))))