You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by qk...@apache.org on 2017/08/09 06:45:04 UTC
[incubator-mxnet] branch master updated: [R] im2rec in R. close
#7273 (#7389)
This is an automated email from the ASF dual-hosted git repository.
qkou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 89e3ee3 [R] im2rec in R. close #7273 (#7389)
89e3ee3 is described below
commit 89e3ee3ea7c223db8c65ddd8c94c6e787d7c52df
Author: Qiang Kou (KK) <qk...@qkou.info>
AuthorDate: Wed Aug 9 06:44:58 2017 +0000
[R] im2rec in R. close #7273 (#7389)
---
Makefile | 3 +-
R-package/R/util.R | 49 +++++-
R-package/src/Makevars | 2 +-
R-package/src/export.cc | 2 +-
R-package/src/im2rec.cc | 269 +++++++++++++++++++++++++++++++
R-package/src/im2rec.h | 42 +++++
R-package/src/mxnet.cc | 3 +
R-package/vignettes/CatsDogsFinetune.Rmd | 55 ++++---
8 files changed, 402 insertions(+), 23 deletions(-)
diff --git a/Makefile b/Makefile
index 5c7f54d..ed74214 100644
--- a/Makefile
+++ b/Makefile
@@ -379,6 +379,7 @@ rcpplint:
rpkg:
mkdir -p R-package/inst
mkdir -p R-package/inst/libs
+ cp src/io/image_recordio.h R-package/src
cp -rf lib/libmxnet.so R-package/inst/libs
mkdir -p R-package/inst/include
cp -rf include/* R-package/inst/include
@@ -442,7 +443,7 @@ clean: cyclean $(EXTRA_PACKAGES_CLEAN)
else
clean: cyclean testclean $(EXTRA_PACKAGES_CLEAN)
$(RM) -r build lib bin *~ */*~ */*/*~ */*/*/*~ R-package/NAMESPACE R-package/man R-package/R/mxnet_generated.R \
- R-package/inst R-package/src/*.o R-package/src/*.so mxnet_*.tar.gz
+ R-package/inst R-package/src/image_recordio.h R-package/src/*.o R-package/src/*.so mxnet_*.tar.gz
cd $(DMLC_CORE); $(MAKE) clean; cd -
cd $(PS_PATH); $(MAKE) clean; cd -
cd $(NNVM_PATH); $(MAKE) clean; cd -
diff --git a/R-package/R/util.R b/R-package/R/util.R
index 2b292d1..acc9510 100644
--- a/R-package/R/util.R
+++ b/R-package/R/util.R
@@ -9,5 +9,52 @@ mx.util.filter.null <- function(lst) {
#'
#' @export
mxnet.export <- function(path) {
- mxnet.internal.export(path.expand(path))
+ mx.internal.export(path.expand(path))
+}
+
+#' Convert images into image recordio format
+#' @param image_lst
+#' The image lst file
+#' @param root
+#' The root folder for image files
+#' @param output_rec
+#' The output rec file
+#' @param label_width
+#' The label width in the list file. Default is 1.
+#' @param pack_label
+#' Whether to also pack multi-dimensional labels in the record file. Default is 0.
+#' @param new_size
+#' The shorter edge of each image will be resized to new_size.
+#' By default (-1) images keep their original size.
+#' @param nsplit
+#' Used for part generation: logically split the image list into nsplit parts by position.
+#' Default is 1.
+#' @param partid
+#' Used for part generation: pack only the images from this part of the image list.
+#' Default is 0.
+#' @param center_crop
+#' Whether to crop the center of the image to make it square (only applied when new_size > 0). Default is 0.
+#' @param quality
+#' JPEG quality for encoding (1-100, default: 95) or PNG compression for encoding (1-9, default: 3).
+#' @param color_mode
+#' Force color (1), gray image (0) or keep source unchanged (-1). Default is 1.
+#' @param unchanged
+#' Keep the original image encoding, size and color. If set to 1, the raw file bytes are packed and the resize, crop, color and encoding options are not applied.
+#' @param inter_method
+#' NN(0), BILINEAR(1), CUBIC(2), AREA(3), LANCZOS4(4), AUTO(9), RAND(10). Default is 1.
+#' @param encoding
+#' The encoding type for images. It can be '.jpg' or '.png'. Default is '.jpg'.
+#' @export
+im2rec <- function(image_lst, root, output_rec, label_width = 1L,
+ pack_label = 0L, new_size = -1L, nsplit = 1L,
+ partid = 0L, center_crop = 0L, quality = 95L,
+ color_mode = 1L, unchanged = 0L, inter_method = 1L,
+ encoding = ".jpg") {
+ image_lst <- path.expand(image_lst) # expand a leading "~" before the paths reach the C++ side
+ root <- path.expand(root)
+ output_rec <- path.expand(output_rec)
+ mx.internal.im2rec(image_lst, root, output_rec, label_width, # arguments are forwarded positionally
+ pack_label, new_size, nsplit, partid,
+ center_crop, quality, color_mode, unchanged,
+ inter_method, encoding)
}
diff --git a/R-package/src/Makevars b/R-package/src/Makevars
index a9cdabf..c089c09 100644
--- a/R-package/src/Makevars
+++ b/R-package/src/Makevars
@@ -1,3 +1,3 @@
-
+CXX_STD = CXX11
PKG_CPPFLAGS = -I../inst/include
PKG_LIBS = $(LAPACK_LIBS) $(BLAS_LIBS)
diff --git a/R-package/src/export.cc b/R-package/src/export.cc
index 2377a02..ef77d25 100644
--- a/R-package/src/export.cc
+++ b/R-package/src/export.cc
@@ -41,7 +41,7 @@ Exporter* Exporter::Get() {
void Exporter::InitRcppModule() {
using namespace Rcpp; // NOLINT(*)
Exporter::Get()->scope_ = ::getCurrentScope();
- function("mxnet.internal.export", &Exporter::Export,
+ function("mx.internal.export", &Exporter::Export,
Rcpp::List::create(_["path"]),
"Internal function of mxnet, used to export generated functions file.");
}
diff --git a/R-package/src/im2rec.cc b/R-package/src/im2rec.cc
new file mode 100644
index 0000000..0c6bea9
--- /dev/null
+++ b/R-package/src/im2rec.cc
@@ -0,0 +1,269 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file im2rec.cc
+ * \brief Packs the images named in a list file into a RecordIO file;
+ * exposed to R as mx.internal.im2rec
+ */
+
+#include <cctype>
+#include <cstring>
+#include <string>
+#include <vector>
+#include <iomanip>
+#include <sstream>
+#include <random>
+#include "dmlc/base.h"
+#include "dmlc/io.h"
+#include "dmlc/timer.h"
+#include "dmlc/logging.h"
+#include "dmlc/recordio.h"
+#include <opencv2/opencv.hpp>
+#include "image_recordio.h"
+#include "base.h"
+#include "im2rec.h"
+
+namespace mxnet {
+namespace R {
+
+int GetInterMethod(int inter_method, int old_width, int old_height,
+ int new_width, int new_height, std::mt19937& prnd) { // NOLINT(*)
+ if (inter_method == 9) { // 9 = AUTO: choose from the direction of the resize
+ if (new_width > old_width && new_height > old_height) {
+ return 2; // CV_INTER_CUBIC for enlarge
+ } else if (new_width <old_width && new_height < old_height) {
+ return 3; // CV_INTER_AREA for shrink
+ } else {
+ return 1; // CV_INTER_LINEAR for others
+ }
+ } else if (inter_method == 10) { // 10 = RAND: draw a method id uniformly from 0..4
+ std::uniform_int_distribution<size_t> rand_uniform_int(0, 4);
+ return rand_uniform_int(prnd);
+ } else {
+ return inter_method; // any other value is passed through unchanged
+ }
+}
+
+IM2REC* IM2REC::Get() {
+ static IM2REC inst; // function-local static: constructed once, on the first call
+ return &inst; // every caller receives the address of that same instance
+}
+
+void IM2REC::InitRcppModule() {
+ using namespace Rcpp; // NOLINT(*)
+ IM2REC::Get()->scope_ = ::getCurrentScope(); // remember the scope that is current during registration
+ function("mx.internal.im2rec", &IM2REC::im2rec, // R-visible name bound to the static im2rec
+ Rcpp::List::create(_["image_lst"], // argument names, in the same order as im2rec's parameters
+ _["root"],
+ _["output_rec"],
+ _["label_width"],
+ _["pack_label"],
+ _["new_size"],
+ _["nsplit"],
+ _["partid"],
+ _["center_crop"],
+ _["quality"],
+ _["color_mode"],
+ _["unchanged"],
+ _["inter_method"],
+ _["encoding"]),
+ ""); // docstring argument is left empty
+}
+
+void IM2REC::im2rec(const std::string & image_lst, const std::string & root,
+ const std::string & output_rec,
+ int label_width, int pack_label, int new_size, int nsplit,
+ int partid, int center_crop, int quality,
+ int color_mode, int unchanged,
+ int inter_method, std::string encoding) {
+ // Validate parameter ranges, then log the effective settings
+ if (color_mode != -1 && color_mode != 0 && color_mode != 1) {
+ Rcpp::stop("Color mode must be -1, 0 or 1.");
+ }
+ if (encoding != std::string(".jpg") && encoding != std::string(".png")) {
+ Rcpp::stop("Encoding mode must be .jpg or .png.");
+ }
+ if (label_width <= 1 && pack_label) { // packing extra labels needs more than one label per line
+ Rcpp::stop("pack_label can only be used when label_width > 1");
+ }
+ if (new_size > 0) {
+ LOG(INFO) << "New Image Size: Short Edge " << new_size;
+ } else {
+ LOG(INFO) << "Keep origin image size";
+ }
+ if (center_crop) {
+ LOG(INFO) << "Center cropping to square";
+ }
+ if (color_mode == 0) {
+ LOG(INFO) << "Use gray images";
+ }
+ if (color_mode == -1) {
+ LOG(INFO) << "Keep original color mode";
+ }
+ LOG(INFO) << "Encoding is " << encoding;
+
+ if (encoding == std::string(".png") && quality > 9) { // a value above 9 is replaced for PNG
+ quality = 3;
+ }
+ if (inter_method != 1) { // 1 (bilinear) is not logged; unlisted values are logged by no case
+ switch (inter_method) {
+ case 0:
+ LOG(INFO) << "Use inter_method CV_INTER_NN";
+ break;
+ case 2:
+ LOG(INFO) << "Use inter_method CV_INTER_CUBIC";
+ break;
+ case 3:
+ LOG(INFO) << "Use inter_method CV_INTER_AREA";
+ break;
+ case 4:
+ LOG(INFO) << "Use inter_method CV_INTER_LANCZOS4";
+ break;
+ case 9:
+ LOG(INFO) << "Use inter_method mod auto(cubic for enlarge, area for shrink)";
+ break;
+ case 10:
+ LOG(INFO) << "Use inter_method mod rand(nn/bilinear/cubic/area/lanczos4)";
+ break;
+ }
+ }
+ std::random_device rd;
+ std::mt19937 prnd(rd()); // generator handed to GetInterMethod for the RAND (10) mode
+ using namespace dmlc; // e.g. GetTime() is used unqualified further down
+ static const size_t kBufferSize = 1 << 20UL; // image files are read in chunks of this many bytes
+ mxnet::io::ImageRecordIO rec;
+ size_t imcnt = 0;
+ double tstart = dmlc::GetTime();
+ dmlc::InputSplit *flist =
+ dmlc::InputSplit::Create(image_lst.c_str(), partid, nsplit, "text");
+ std::ostringstream os;
+ if (nsplit == 1) {
+ os << output_rec;
+ } else {
+ os << output_rec << ".part" << std::setw(3) << std::setfill('0') << partid; // e.g. out.rec.part007
+ }
+ LOG(INFO) << "Write to output: " << os.str();
+ dmlc::Stream *fo = dmlc::Stream::Create(os.str().c_str(), "w");
+ LOG(INFO) << "Output: " << os.str(); // NOTE(review): logs the same path as two lines above
+ dmlc::RecordIOWriter writer(fo);
+ std::string fname, path, blob;
+ std::vector<unsigned char> decode_buf;
+ std::vector<unsigned char> encode_buf;
+ std::vector<int> encode_params; // (flag, value) pair passed to cv::imencode
+ if (encoding == std::string(".png")) {
+ encode_params.push_back(CV_IMWRITE_PNG_COMPRESSION);
+ encode_params.push_back(quality);
+ LOG(INFO) << "PNG encoding compression: " << quality;
+ } else {
+ encode_params.push_back(CV_IMWRITE_JPEG_QUALITY);
+ encode_params.push_back(quality);
+ LOG(INFO) << "JPEG encoding quality: " << quality;
+ }
+ dmlc::InputSplit::Blob line;
+ std::vector<float> label_buf(label_width, 0.f);
+
+ while (flist->NextRecord(&line)) { // one iteration per line of the list file
+ std::string sline(static_cast<char*>(line.dptr), line.size);
+ std::istringstream is(sline);
+ if (!(is >> rec.header.image_id[0] >> rec.header.label)) continue; // skip lines without id + label
+ label_buf[0] = rec.header.label;
+ for (int k = 1; k < label_width; ++k) {
+ RCHECK(is >> label_buf[k])
+ << "Invalid ImageList, did you provide the correct label_width?";
+ }
+ if (pack_label) rec.header.flag = label_width; // flag carries the label count when labels are packed
+ rec.SaveHeader(&blob); // presumably replaces blob with the serialized header - confirm in image_recordio.h
+ if (pack_label) {
+ size_t bsize = blob.size();
+ blob.resize(bsize + label_buf.size()*sizeof(float)); // append the raw float labels after the header
+ memcpy(BeginPtr(blob) + bsize,
+ BeginPtr(label_buf), label_buf.size()*sizeof(float));
+ }
+ RCHECK(std::getline(is, fname)); // the rest of the line is the image file name
+ // strip trailing whitespace and non-printable characters
+ while (fname.length() != 0 &&
+ (isspace(*fname.rbegin()) || !isprint(*fname.rbegin()))) {
+ fname.resize(fname.length() - 1);
+ }
+ // skip leading whitespace
+ const char *p = fname.c_str();
+ while (isspace(*p)) ++p;
+ path = root + p; // plain concatenation: root must already end with a path separator
+ // "r" is equivalent to "rb" in dmlc::Stream
+ dmlc::Stream *fi = dmlc::Stream::Create(path.c_str(), "r"); // NOTE(review): result is used unchecked - confirm Create's failure behavior
+ decode_buf.clear();
+ size_t imsize = 0;
+ while (true) { // grow the buffer one chunk at a time; a short read ends the loop
+ decode_buf.resize(imsize + kBufferSize);
+ size_t nread = fi->Read(BeginPtr(decode_buf) + imsize, kBufferSize);
+ imsize += nread;
+ decode_buf.resize(imsize); // trim to the bytes actually read so far
+ if (nread != kBufferSize) break;
+ }
+ delete fi;
+
+
+ if (unchanged != 1) { // decode, optionally crop/resize, then re-encode
+ cv::Mat img = cv::imdecode(decode_buf, color_mode); // color_mode is passed straight through as the decode flag
+ RCHECK(img.data != NULL) << "OpenCV decode fail:" << path;
+ cv::Mat res = img; // res stays the decoded image when no resize is requested
+ if (new_size > 0) { // NOTE(review): center_crop below only takes effect inside this branch
+ if (center_crop) {
+ if (img.rows > img.cols) {
+ int margin = (img.rows - img.cols)/2;
+ img = img(cv::Range(margin, margin+img.cols), cv::Range(0, img.cols));
+ } else {
+ int margin = (img.cols - img.rows)/2;
+ img = img(cv::Range(0, img.rows), cv::Range(margin, margin + img.rows));
+ }
+ }
+ int interpolation_method = 1;
+ if (img.rows > img.cols) { // taller than wide: the width is the short edge
+ if (img.cols != new_size) {
+ interpolation_method = GetInterMethod(inter_method, img.cols, img.rows,
+ new_size,
+ img.rows * new_size / img.cols, prnd);
+ cv::resize(img, res, cv::Size(new_size,
+ img.rows * new_size / img.cols),
+ 0, 0, interpolation_method);
+ } else {
+ res = img.clone(); // short edge already equals new_size: copy without resampling
+ }
+ } else { // wider than tall, or square: the height is the short edge
+ if (img.rows != new_size) {
+ interpolation_method = GetInterMethod(inter_method, img.cols,
+ img.rows, new_size * img.cols / img.rows,
+ new_size, prnd);
+ cv::resize(img, res, cv::Size(new_size * img.cols / img.rows,
+ new_size), 0, 0, interpolation_method);
+ } else {
+ res = img.clone(); // short edge already equals new_size: copy without resampling
+ }
+ }
+ }
+ encode_buf.clear();
+ RCHECK(cv::imencode(encoding, res, encode_buf, encode_params));
+
+ // append the re-encoded image bytes after the header (and packed labels)
+ size_t bsize = blob.size();
+ blob.resize(bsize + encode_buf.size());
+ memcpy(BeginPtr(blob) + bsize,
+ BeginPtr(encode_buf), encode_buf.size());
+ } else { // unchanged == 1: append the file bytes exactly as read
+ size_t bsize = blob.size();
+ blob.resize(bsize + decode_buf.size());
+ memcpy(BeginPtr(blob) + bsize,
+ BeginPtr(decode_buf), decode_buf.size());
+ }
+ writer.WriteRecord(BeginPtr(blob), blob.size());
+ // count the record and report progress every 1000 images
+ ++imcnt;
+ if (imcnt % 1000 == 0) {
+ LOG(INFO) << imcnt << " images processed, " << GetTime() - tstart << " sec elapsed";
+ }
+ }
+ LOG(INFO) << "Total: " << imcnt << " images processed, " << GetTime() - tstart << " sec elapsed";
+ delete fo; // NOTE(review): fo and flist are not released if a check above fails mid-run - confirm RCHECK semantics
+ delete flist;
+}
+} // namespace R
+} // namespace mxnet
diff --git a/R-package/src/im2rec.h b/R-package/src/im2rec.h
new file mode 100644
index 0000000..a98a733
--- /dev/null
+++ b/R-package/src/im2rec.h
@@ -0,0 +1,42 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file im2rec.h
+ * \brief Declares IM2REC, which packs images from a list file into a
+ * RecordIO file and registers that function with the Rcpp module
+ */
+
+#ifndef MXNET_RCPP_IM2REC_H_
+#define MXNET_RCPP_IM2REC_H_
+
+#include <Rcpp.h>
+#include <string>
+
+namespace mxnet {
+namespace R {
+
+class IM2REC {
+ public:
+ /*!
+ * \brief Pack the images named in a list file into a RecordIO file.
+ * \param image_lst The image list file; the other arguments match the R-level im2rec().
+ */
+ static void im2rec(const std::string & image_lst, const std::string & root,
+ const std::string & output_rec,
+ int label_width = 1, int pack_label = 0, int new_size = -1, int nsplit = 1,
+ int partid = 0, int center_crop = 0, int quality = 95,
+ int color_mode = 1, int unchanged = 0,
+ int inter_method = 1, std::string encoding = ".jpg");
+ // initialize the Rcpp module
+ static void InitRcppModule();
+
+ public:
+ // get the singleton of IM2REC
+ static IM2REC* Get();
+ /*! \brief The scope of current module to export */
+ Rcpp::Module* scope_;
+};
+
+} // namespace R
+} // namespace mxnet
+
+#endif // MXNET_RCPP_IM2REC_H_
diff --git a/R-package/src/mxnet.cc b/R-package/src/mxnet.cc
index 9d16190..9f8239b 100644
--- a/R-package/src/mxnet.cc
+++ b/R-package/src/mxnet.cc
@@ -12,6 +12,7 @@
#include "./io.h"
#include "./kvstore.h"
#include "./export.h"
+#include "./im2rec.h"
namespace mxnet {
namespace R {
@@ -56,4 +57,6 @@ RCPP_MODULE(mxnet) {
DataIterCreateFunction::InitRcppModule();
KVStore::InitRcppModule();
Exporter::InitRcppModule();
+ IM2REC::InitRcppModule();
}
+
diff --git a/R-package/vignettes/CatsDogsFinetune.Rmd b/R-package/vignettes/CatsDogsFinetune.Rmd
index e30b513..95f90be 100644
--- a/R-package/vignettes/CatsDogsFinetune.Rmd
+++ b/R-package/vignettes/CatsDogsFinetune.Rmd
@@ -104,12 +104,30 @@ Map(function(x, y) {
}, x = files, y = new_names)
```
-### Creating .rec files using im2rec.py
-
-```{bash, eval = FALSE}
-python im2rec.py --list=1 --recursive=1 --train-ratio=0.8 cats_dogs train_pad_224x224
-python im2rec.py --num-thread=4 --pass-through=1 cats_dogs_train.lst train_pad_224x224
-python im2rec.py --num-thread=4 --pass-through=1 cats_dogs_val.lst train_pad_224x224
+### Creating .rec files
+
+```{r, eval = FALSE}
+cat_files <- list.files("train_pad_224x224/cat/", recursive=TRUE)
+cat_files <- paste0("cat/", cat_files) # paths are written relative to the root passed to im2rec()
+
+dog_files <- list.files("train_pad_224x224/dog/", recursive=TRUE)
+dog_files <- paste0("dog/", dog_files)
+
+train_ind <- sample(length(cat_files), length(cat_files) * 0.8) # 80% of the cat indices, drawn at random
+train_data <- c(1:(length(train_ind) * 2)) # column 1: running image id
+train_data <- cbind(train_data, c(rep(0, length(train_ind)), rep(1, length(train_ind)))) # column 2: label, 0 = cat, 1 = dog
+train_data <- cbind(train_data, c(cat_files[train_ind], dog_files[train_ind])) # NOTE(review): cat indices are reused for dogs - assumes both classes have the same number of files
+train_data <- train_data[sample(nrow(train_data)),] # shuffle the rows
+write.table(train_data, "cats_dogs_train.lst", quote = FALSE, sep = "\t", row.names = FALSE, col.names = FALSE)
+im2rec("cats_dogs_train.lst", "train_pad_224x224/", "cats_dogs_train.rec")
+
+val_ind <- c(1:length(cat_files))[!c(1:length(cat_files)) %in% train_ind] # indices not drawn for training
+val_data <- c(1:(length(val_ind) * 2)) # column 1: running image id
+val_data <- cbind(val_data, c(rep(0, length(val_ind)), rep(1, length(val_ind)))) # column 2: label, 0 = cat, 1 = dog
+val_data <- cbind(val_data, c(cat_files[val_ind], dog_files[val_ind])) # NOTE(review): same equal-count assumption as the training split
+val_data <- val_data[sample(nrow(val_data)),] # shuffle the rows
+write.table(val_data, "cats_dogs_val.lst", quote = FALSE, sep = "\t", row.names = FALSE, col.names = FALSE)
+im2rec("cats_dogs_val.lst", "train_pad_224x224/", "cats_dogs_val.rec")
```
## The data iterator
@@ -215,21 +233,20 @@ preprocImage<- function(src, # URL or file location
num_channels = 3, # 3 for RGB, 1 for grayscale
mult_by = 1, # set to 255 for normalized image
crop = FALSE) { # no crop by default
-
- im <- load.image(src)
-
- if (crop) {
- shape <- dim(im)
+ im <- load.image(src)
+
+ if (crop) {
+ shape <- dim(im)
short_edge <- min(shape[1:2])
- xx <- floor((shape[1] - short_edge) / 2)
- yy <- floor((shape[2] - short_edge) / 2)
+ xx <- floor((shape[1] - short_edge) / 2)
+ yy <- floor((shape[2] - short_edge) / 2)
im <- crop.borders(im, xx, yy)
- }
-
- resized <- resize(im, size_x = width, size_y = height)
- arr <- as.array(resized) * mult_by
- dim(arr) <- c(width, height, num_channels, 1)
- return(arr)
+ }
+
+ resized <- resize(im, size_x = width, size_y = height)
+ arr <- as.array(resized) * mult_by
+ dim(arr) <- c(width, height, num_channels, 1)
+ return(arr)
}
```
--
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].