Posted to commits@spark.apache.org by rx...@apache.org on 2015/01/28 01:08:45 UTC

[4/5] spark git commit: [SPARK-5097][SQL] DataFrame

http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 1990323..7d7550c 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -20,15 +20,19 @@ public classes of Spark SQL:
 
     - L{SQLContext}
       Main entry point for SQL functionality.
-    - L{SchemaRDD}
+    - L{DataFrame}
       A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In
-      addition to normal RDD operations, SchemaRDDs also support SQL.
+      addition to normal RDD operations, DataFrames also support SQL.
+    - L{GroupedDataFrame}
+      Aggregation methods, returned by DataFrame.groupBy().
+    - L{Column}
+      A single column of a L{DataFrame}.
     - L{Row}
       A Row of data returned by a Spark SQL query.
     - L{HiveContext}
       Main entry point for accessing data stored in Apache Hive.
 """
 
+import sys
 import itertools
 import decimal
 import datetime
@@ -36,6 +40,9 @@ import keyword
 import warnings
 import json
 import re
+import random
+import os
+from tempfile import NamedTemporaryFile
 from array import array
 from operator import itemgetter
 from itertools import imap
@@ -43,6 +50,7 @@ from itertools import imap
 from py4j.protocol import Py4JError
 from py4j.java_collections import ListConverter, MapConverter
 
+from pyspark.context import SparkContext
 from pyspark.rdd import RDD
 from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
     CloudPickleSerializer, UTF8Deserializer
@@ -54,7 +62,8 @@ __all__ = [
     "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType",
     "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
     "ShortType", "ArrayType", "MapType", "StructField", "StructType",
-    "SQLContext", "HiveContext", "SchemaRDD", "Row"]
+    "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row",
+    "SchemaRDD"]
 
 
 class DataType(object):
@@ -1171,7 +1180,7 @@ def _create_cls(dataType):
 
     class Row(tuple):
 
-        """ Row in SchemaRDD """
+        """ Row in DataFrame """
         __DATATYPE__ = dataType
         __FIELDS__ = tuple(f.name for f in dataType.fields)
         __slots__ = ()
@@ -1198,7 +1207,7 @@ class SQLContext(object):
 
     """Main entry point for Spark SQL functionality.
 
-    A SQLContext can be used create L{SchemaRDD}, register L{SchemaRDD} as
+    A SQLContext can be used to create L{DataFrame}, register L{DataFrame} as
     tables, execute SQL over tables, cache tables, and read parquet files.
     """
 
@@ -1209,8 +1218,8 @@ class SQLContext(object):
         :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
         SQLContext in the JVM, instead we make all calls to this object.
 
-        >>> srdd = sqlCtx.inferSchema(rdd)
-        >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL
         Traceback (most recent call last):
             ...
         TypeError:...
@@ -1225,12 +1234,12 @@ class SQLContext(object):
         >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L,
         ...     b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
         ...     time=datetime(2014, 8, 1, 14, 1, 5))])
-        >>> srdd = sqlCtx.inferSchema(allTypes)
-        >>> srdd.registerTempTable("allTypes")
+        >>> df = sqlCtx.inferSchema(allTypes)
+        >>> df.registerTempTable("allTypes")
         >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
         ...            'from allTypes where b and i > 0').collect()
         [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)]
-        >>> srdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
+        >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
         ...                     x.row.a, x.list)).collect()
         [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
         """
@@ -1309,23 +1318,23 @@ class SQLContext(object):
         ...     [Row(field1=1, field2="row1"),
         ...      Row(field1=2, field2="row2"),
         ...      Row(field1=3, field2="row3")])
-        >>> srdd = sqlCtx.inferSchema(rdd)
-        >>> srdd.collect()[0]
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> df.collect()[0]
         Row(field1=1, field2=u'row1')
 
         >>> NestedRow = Row("f1", "f2")
         >>> nestedRdd1 = sc.parallelize([
         ...     NestedRow(array('i', [1, 2]), {"row1": 1.0}),
         ...     NestedRow(array('i', [2, 3]), {"row2": 2.0})])
-        >>> srdd = sqlCtx.inferSchema(nestedRdd1)
-        >>> srdd.collect()
+        >>> df = sqlCtx.inferSchema(nestedRdd1)
+        >>> df.collect()
         [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})]
 
         >>> nestedRdd2 = sc.parallelize([
         ...     NestedRow([[1, 2], [2, 3]], [1, 2]),
         ...     NestedRow([[2, 3], [3, 4]], [2, 3])])
-        >>> srdd = sqlCtx.inferSchema(nestedRdd2)
-        >>> srdd.collect()
+        >>> df = sqlCtx.inferSchema(nestedRdd2)
+        >>> df.collect()
         [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])]
 
         >>> from collections import namedtuple
@@ -1334,13 +1343,13 @@ class SQLContext(object):
         ...     [CustomRow(field1=1, field2="row1"),
         ...      CustomRow(field1=2, field2="row2"),
         ...      CustomRow(field1=3, field2="row3")])
-        >>> srdd = sqlCtx.inferSchema(rdd)
-        >>> srdd.collect()[0]
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> df.collect()[0]
         Row(field1=1, field2=u'row1')
         """
 
-        if isinstance(rdd, SchemaRDD):
-            raise TypeError("Cannot apply schema to SchemaRDD")
+        if isinstance(rdd, DataFrame):
+            raise TypeError("Cannot apply schema to DataFrame")
 
         first = rdd.first()
         if not first:
@@ -1384,10 +1393,10 @@ class SQLContext(object):
         >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")])
         >>> schema = StructType([StructField("field1", IntegerType(), False),
         ...     StructField("field2", StringType(), False)])
-        >>> srdd = sqlCtx.applySchema(rdd2, schema)
-        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
-        >>> srdd2 = sqlCtx.sql("SELECT * from table1")
-        >>> srdd2.collect()
+        >>> df = sqlCtx.applySchema(rdd2, schema)
+        >>> sqlCtx.registerRDDAsTable(df, "table1")
+        >>> df2 = sqlCtx.sql("SELECT * from table1")
+        >>> df2.collect()
         [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
 
         >>> from datetime import date, datetime
@@ -1410,15 +1419,15 @@ class SQLContext(object):
         ...         StructType([StructField("b", ShortType(), False)]), False),
         ...     StructField("list", ArrayType(ByteType(), False), False),
         ...     StructField("null", DoubleType(), True)])
-        >>> srdd = sqlCtx.applySchema(rdd, schema)
-        >>> results = srdd.map(
+        >>> df = sqlCtx.applySchema(rdd, schema)
+        >>> results = df.map(
         ...     lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date,
         ...         x.time, x.map["a"], x.struct.b, x.list, x.null))
         >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
         (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1),
              datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
 
-        >>> srdd.registerTempTable("table2")
+        >>> df.registerTempTable("table2")
         >>> sqlCtx.sql(
         ...   "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
         ...     "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " +
@@ -1431,13 +1440,13 @@ class SQLContext(object):
         >>> abstract = "byte short float time map{} struct(b) list[]"
         >>> schema = _parse_schema_abstract(abstract)
         >>> typedSchema = _infer_schema_type(rdd.first(), schema)
-        >>> srdd = sqlCtx.applySchema(rdd, typedSchema)
-        >>> srdd.collect()
+        >>> df = sqlCtx.applySchema(rdd, typedSchema)
+        >>> df.collect()
         [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])]
         """
 
-        if isinstance(rdd, SchemaRDD):
-            raise TypeError("Cannot apply schema to SchemaRDD")
+        if isinstance(rdd, DataFrame):
+            raise TypeError("Cannot apply schema to DataFrame")
 
         if not isinstance(schema, StructType):
             raise TypeError("schema should be StructType")
@@ -1457,8 +1466,8 @@ class SQLContext(object):
         rdd = rdd.map(converter)
 
         jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
-        srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
-        return SchemaRDD(srdd, self)
+        df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
+        return DataFrame(df, self)
 
     def registerRDDAsTable(self, rdd, tableName):
         """Registers the given RDD as a temporary table in the catalog.
@@ -1466,34 +1475,34 @@ class SQLContext(object):
         Temporary tables exist only during the lifetime of this instance of
         SQLContext.
 
-        >>> srdd = sqlCtx.inferSchema(rdd)
-        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> sqlCtx.registerRDDAsTable(df, "table1")
         """
-        if (rdd.__class__ is SchemaRDD):
-            srdd = rdd._jschema_rdd.baseSchemaRDD()
-            self._ssql_ctx.registerRDDAsTable(srdd, tableName)
+        if (rdd.__class__ is DataFrame):
+            df = rdd._jdf
+            self._ssql_ctx.registerRDDAsTable(df, tableName)
         else:
-            raise ValueError("Can only register SchemaRDD as table")
+            raise ValueError("Can only register DataFrame as table")
 
     def parquetFile(self, path):
-        """Loads a Parquet file, returning the result as a L{SchemaRDD}.
+        """Loads a Parquet file, returning the result as a L{DataFrame}.
 
         >>> import tempfile, shutil
         >>> parquetFile = tempfile.mkdtemp()
         >>> shutil.rmtree(parquetFile)
-        >>> srdd = sqlCtx.inferSchema(rdd)
-        >>> srdd.saveAsParquetFile(parquetFile)
-        >>> srdd2 = sqlCtx.parquetFile(parquetFile)
-        >>> sorted(srdd.collect()) == sorted(srdd2.collect())
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> df.saveAsParquetFile(parquetFile)
+        >>> df2 = sqlCtx.parquetFile(parquetFile)
+        >>> sorted(df.collect()) == sorted(df2.collect())
         True
         """
-        jschema_rdd = self._ssql_ctx.parquetFile(path)
-        return SchemaRDD(jschema_rdd, self)
+        jdf = self._ssql_ctx.parquetFile(path)
+        return DataFrame(jdf, self)
 
     def jsonFile(self, path, schema=None, samplingRatio=1.0):
         """
         Loads a text file storing one JSON object per line as a
-        L{SchemaRDD}.
+        L{DataFrame}.
 
         If the schema is provided, applies the given schema to this
         JSON dataset.
@@ -1508,23 +1517,23 @@ class SQLContext(object):
         >>> for json in jsonStrings:
         ...   print>>ofn, json
         >>> ofn.close()
-        >>> srdd1 = sqlCtx.jsonFile(jsonFile)
-        >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
-        >>> srdd2 = sqlCtx.sql(
+        >>> df1 = sqlCtx.jsonFile(jsonFile)
+        >>> sqlCtx.registerRDDAsTable(df1, "table1")
+        >>> df2 = sqlCtx.sql(
         ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, "
         ...   "field6 as f4 from table1")
-        >>> for r in srdd2.collect():
+        >>> for r in df2.collect():
         ...     print r
         Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
         Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
         Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
 
-        >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema())
-        >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
-        >>> srdd4 = sqlCtx.sql(
+        >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema())
+        >>> sqlCtx.registerRDDAsTable(df3, "table2")
+        >>> df4 = sqlCtx.sql(
         ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, "
         ...   "field6 as f4 from table2")
-        >>> for r in srdd4.collect():
+        >>> for r in df4.collect():
         ...    print r
         Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
         Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
@@ -1536,23 +1545,23 @@ class SQLContext(object):
         ...         StructType([
         ...             StructField("field5",
         ...                 ArrayType(IntegerType(), False), True)]), False)])
-        >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema)
-        >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
-        >>> srdd6 = sqlCtx.sql(
+        >>> df5 = sqlCtx.jsonFile(jsonFile, schema)
+        >>> sqlCtx.registerRDDAsTable(df5, "table3")
+        >>> df6 = sqlCtx.sql(
         ...   "SELECT field2 AS f1, field3.field5 as f2, "
         ...   "field3.field5[0] as f3 from table3")
-        >>> srdd6.collect()
+        >>> df6.collect()
         [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
         """
         if schema is None:
-            srdd = self._ssql_ctx.jsonFile(path, samplingRatio)
+            df = self._ssql_ctx.jsonFile(path, samplingRatio)
         else:
             scala_datatype = self._ssql_ctx.parseDataType(schema.json())
-            srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
-        return SchemaRDD(srdd, self)
+            df = self._ssql_ctx.jsonFile(path, scala_datatype)
+        return DataFrame(df, self)
 
     def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
-        """Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
+        """Loads an RDD storing one JSON object per string as a L{DataFrame}.
 
         If the schema is provided, applies the given schema to this
         JSON dataset.
@@ -1560,23 +1569,23 @@ class SQLContext(object):
         Otherwise, it samples the dataset with ratio `samplingRatio` to
         determine the schema.
 
-        >>> srdd1 = sqlCtx.jsonRDD(json)
-        >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
-        >>> srdd2 = sqlCtx.sql(
+        >>> df1 = sqlCtx.jsonRDD(json)
+        >>> sqlCtx.registerRDDAsTable(df1, "table1")
+        >>> df2 = sqlCtx.sql(
         ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, "
         ...   "field6 as f4 from table1")
-        >>> for r in srdd2.collect():
+        >>> for r in df2.collect():
         ...     print r
         Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
         Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
         Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
 
-        >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema())
-        >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
-        >>> srdd4 = sqlCtx.sql(
+        >>> df3 = sqlCtx.jsonRDD(json, df1.schema())
+        >>> sqlCtx.registerRDDAsTable(df3, "table2")
+        >>> df4 = sqlCtx.sql(
         ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, "
         ...   "field6 as f4 from table2")
-        >>> for r in srdd4.collect():
+        >>> for r in df4.collect():
         ...     print r
         Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
         Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
@@ -1588,12 +1597,12 @@ class SQLContext(object):
         ...         StructType([
         ...             StructField("field5",
         ...                 ArrayType(IntegerType(), False), True)]), False)])
-        >>> srdd5 = sqlCtx.jsonRDD(json, schema)
-        >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
-        >>> srdd6 = sqlCtx.sql(
+        >>> df5 = sqlCtx.jsonRDD(json, schema)
+        >>> sqlCtx.registerRDDAsTable(df5, "table3")
+        >>> df6 = sqlCtx.sql(
         ...   "SELECT field2 AS f1, field3.field5 as f2, "
         ...   "field3.field5[0] as f3 from table3")
-        >>> srdd6.collect()
+        >>> df6.collect()
         [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)]
 
         >>> sqlCtx.jsonRDD(sc.parallelize(['{}',
@@ -1615,33 +1624,33 @@ class SQLContext(object):
         keyed._bypass_serializer = True
         jrdd = keyed._jrdd.map(self._jvm.BytesToString())
         if schema is None:
-            srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio)
+            df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio)
         else:
             scala_datatype = self._ssql_ctx.parseDataType(schema.json())
-            srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
-        return SchemaRDD(srdd, self)
+            df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
+        return DataFrame(df, self)
 
     def sql(self, sqlQuery):
-        """Return a L{SchemaRDD} representing the result of the given query.
+        """Return a L{DataFrame} representing the result of the given query.
 
-        >>> srdd = sqlCtx.inferSchema(rdd)
-        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
-        >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
-        >>> srdd2.collect()
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> sqlCtx.registerRDDAsTable(df, "table1")
+        >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
+        >>> df2.collect()
         [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
         """
-        return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)
+        return DataFrame(self._ssql_ctx.sql(sqlQuery), self)
 
     def table(self, tableName):
-        """Returns the specified table as a L{SchemaRDD}.
+        """Returns the specified table as a L{DataFrame}.
 
-        >>> srdd = sqlCtx.inferSchema(rdd)
-        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
-        >>> srdd2 = sqlCtx.table("table1")
-        >>> sorted(srdd.collect()) == sorted(srdd2.collect())
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> sqlCtx.registerRDDAsTable(df, "table1")
+        >>> df2 = sqlCtx.table("table1")
+        >>> sorted(df.collect()) == sorted(df2.collect())
         True
         """
-        return SchemaRDD(self._ssql_ctx.table(tableName), self)
+        return DataFrame(self._ssql_ctx.table(tableName), self)
 
     def cacheTable(self, tableName):
         """Caches the specified table in-memory."""
@@ -1707,7 +1716,7 @@ def _create_row(fields, values):
 class Row(tuple):
 
     """
-    A row in L{SchemaRDD}. The fields in it can be accessed like attributes.
+    A row in L{DataFrame}. The fields in it can be accessed like attributes.
 
     Row can be used to create a row object by using named arguments;
     the fields will be sorted by names.
@@ -1799,111 +1808,119 @@ def inherit_doc(cls):
     return cls
 
 
-@inherit_doc
-class SchemaRDD(RDD):
+class DataFrame(object):
 
-    """An RDD of L{Row} objects that has an associated schema.
+    """A collection of rows that have the same columns.
 
-    The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can
-    utilize the relational query api exposed by Spark SQL.
+    A :class:`DataFrame` is equivalent to a relational table in Spark SQL,
+    and can be created using various functions in :class:`SQLContext`::
 
-    For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the
-    L{SchemaRDD} is not operated on directly, as it's underlying
-    implementation is an RDD composed of Java objects. Instead it is
-    converted to a PythonRDD in the JVM, on which Python operations can
-    be done.
+        people = sqlContext.parquetFile("...")
 
-    This class receives raw tuples from Java but assigns a class to it in
-    all its data-collection methods (mapPartitionsWithIndex, collect, take,
-    etc) so that PySpark sees them as Row objects with named fields.
+    Once created, it can be manipulated using the various domain-specific
+    language (DSL) functions defined in :class:`DataFrame` and :class:`Column`.
+
+    To select a column from the DataFrame, use attribute access::
+
+        ageCol = people.age
+
+    Note that the :class:`Column` type can also be manipulated
+    through its various functions::
+
+        # The following creates a new column that increases everybody's age by 10.
+        people.age + 10
+
+
+    A more concrete example::
+
+        # To create DataFrame using SQLContext
+        people = sqlContext.parquetFile("...")
+        department = sqlContext.parquetFile("...")
+
+        people.filter(people.age > 30).join(department, people.deptId == department.id) \
+          .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
     """
 
-    def __init__(self, jschema_rdd, sql_ctx):
+    def __init__(self, jdf, sql_ctx):
+        self._jdf = jdf
         self.sql_ctx = sql_ctx
-        self._sc = sql_ctx._sc
-        clsName = jschema_rdd.getClass().getName()
-        assert clsName.endswith("SchemaRDD"), "jschema_rdd must be SchemaRDD"
-        self._jschema_rdd = jschema_rdd
-        self._id = None
+        self._sc = sql_ctx and sql_ctx._sc
         self.is_cached = False
-        self.is_checkpointed = False
-        self.ctx = self.sql_ctx._sc
-        # the _jrdd is created by javaToPython(), serialized by pickle
-        self._jrdd_deserializer = AutoBatchedSerializer(PickleSerializer())
 
     @property
-    def _jrdd(self):
-        """Lazy evaluation of PythonRDD object.
+    def rdd(self):
+        """Return the content of the :class:`DataFrame` as an :class:`RDD`
+        of :class:`Row` objects."""
+        if not hasattr(self, '_lazy_rdd'):
+            jrdd = self._jdf.javaToPython()
+            rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
+            schema = self.schema()
 
-        Only done when a user calls methods defined by the
-        L{pyspark.rdd.RDD} super class (map, filter, etc.).
-        """
-        if not hasattr(self, '_lazy_jrdd'):
-            self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython()
-        return self._lazy_jrdd
+            def applySchema(it):
+                cls = _create_cls(schema)
+                return itertools.imap(cls, it)
 
-    def id(self):
-        if self._id is None:
-            self._id = self._jrdd.id()
-        return self._id
+            self._lazy_rdd = rdd.mapPartitions(applySchema)
+
+        return self._lazy_rdd
 
     def limit(self, num):
         """Limit the result count to the number specified.
 
-        >>> srdd = sqlCtx.inferSchema(rdd)
-        >>> srdd.limit(2).collect()
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> df.limit(2).collect()
         [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
-        >>> srdd.limit(0).collect()
+        >>> df.limit(0).collect()
         []
         """
-        rdd = self._jschema_rdd.baseSchemaRDD().limit(num)
-        return SchemaRDD(rdd, self.sql_ctx)
+        jdf = self._jdf.limit(num)
+        return DataFrame(jdf, self.sql_ctx)
 
     def toJSON(self, use_unicode=False):
-        """Convert a SchemaRDD into a MappedRDD of JSON documents; one document per row.
+        """Convert a DataFrame into a MappedRDD of JSON documents; one document per row.
 
-        >>> srdd1 = sqlCtx.jsonRDD(json)
-        >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
-        >>> srdd2 = sqlCtx.sql( "SELECT * from table1")
-        >>> srdd2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}'
+        >>> df1 = sqlCtx.jsonRDD(json)
+        >>> sqlCtx.registerRDDAsTable(df1, "table1")
+        >>> df2 = sqlCtx.sql( "SELECT * from table1")
+        >>> df2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}'
         True
-        >>> srdd3 = sqlCtx.sql( "SELECT field3.field4 from table1")
-        >>> srdd3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}']
+        >>> df3 = sqlCtx.sql( "SELECT field3.field4 from table1")
+        >>> df3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}']
         True
         """
-        rdd = self._jschema_rdd.baseSchemaRDD().toJSON()
+        rdd = self._jdf.toJSON()
         return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
 
     def saveAsParquetFile(self, path):
         """Save the contents as a Parquet file, preserving the schema.
 
         Files that are written out using this method can be read back in as
-        a SchemaRDD using the L{SQLContext.parquetFile} method.
+        a DataFrame using the L{SQLContext.parquetFile} method.
 
         >>> import tempfile, shutil
         >>> parquetFile = tempfile.mkdtemp()
         >>> shutil.rmtree(parquetFile)
-        >>> srdd = sqlCtx.inferSchema(rdd)
-        >>> srdd.saveAsParquetFile(parquetFile)
-        >>> srdd2 = sqlCtx.parquetFile(parquetFile)
-        >>> sorted(srdd2.collect()) == sorted(srdd.collect())
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> df.saveAsParquetFile(parquetFile)
+        >>> df2 = sqlCtx.parquetFile(parquetFile)
+        >>> sorted(df2.collect()) == sorted(df.collect())
         True
         """
-        self._jschema_rdd.saveAsParquetFile(path)
+        self._jdf.saveAsParquetFile(path)
 
     def registerTempTable(self, name):
         """Registers this RDD as a temporary table using the given name.
 
         The lifetime of this temporary table is tied to the L{SQLContext}
-        that was used to create this SchemaRDD.
+        that was used to create this DataFrame.
 
-        >>> srdd = sqlCtx.inferSchema(rdd)
-        >>> srdd.registerTempTable("test")
-        >>> srdd2 = sqlCtx.sql("select * from test")
-        >>> sorted(srdd.collect()) == sorted(srdd2.collect())
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> df.registerTempTable("test")
+        >>> df2 = sqlCtx.sql("select * from test")
+        >>> sorted(df.collect()) == sorted(df2.collect())
         True
         """
-        self._jschema_rdd.registerTempTable(name)
+        self._jdf.registerTempTable(name)
 
     def registerAsTable(self, name):
         """DEPRECATED: use registerTempTable() instead"""
@@ -1911,62 +1928,61 @@ class SchemaRDD(RDD):
         self.registerTempTable(name)
 
     def insertInto(self, tableName, overwrite=False):
-        """Inserts the contents of this SchemaRDD into the specified table.
+        """Inserts the contents of this DataFrame into the specified table.
 
         Optionally overwriting any existing data.
         """
-        self._jschema_rdd.insertInto(tableName, overwrite)
+        self._jdf.insertInto(tableName, overwrite)
 
     def saveAsTable(self, tableName):
-        """Creates a new table with the contents of this SchemaRDD."""
-        self._jschema_rdd.saveAsTable(tableName)
+        """Creates a new table with the contents of this DataFrame."""
+        self._jdf.saveAsTable(tableName)
 
     def schema(self):
-        """Returns the schema of this SchemaRDD (represented by
+        """Returns the schema of this DataFrame (represented by
         a L{StructType})."""
-        return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json())
-
-    def schemaString(self):
-        """Returns the output schema in the tree format."""
-        return self._jschema_rdd.schemaString()
+        return _parse_datatype_json_string(self._jdf.schema().json())
 
     def printSchema(self):
         """Prints out the schema in the tree format."""
-        print self.schemaString()
+        print (self._jdf.schema().treeString())
 
     def count(self):
         """Return the number of elements in this RDD.
 
         Unlike the base RDD implementation of count, this implementation
-        leverages the query optimizer to compute the count on the SchemaRDD,
+        leverages the query optimizer to compute the count on the DataFrame,
         which supports features such as filter pushdown.
 
-        >>> srdd = sqlCtx.inferSchema(rdd)
-        >>> srdd.count()
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> df.count()
         3L
-        >>> srdd.count() == srdd.map(lambda x: x).count()
+        >>> df.count() == df.map(lambda x: x).count()
         True
         """
-        return self._jschema_rdd.count()
+        return self._jdf.count()
 
     def collect(self):
-        """Return a list that contains all of the rows in this RDD.
+        """Return a list that contains all of the rows.
 
         Each object in the list is a Row, the fields can be accessed as
         attributes.
 
-        Unlike the base RDD implementation of collect, this implementation
-        leverages the query optimizer to perform a collect on the SchemaRDD,
-        which supports features such as filter pushdown.
-
-        >>> srdd = sqlCtx.inferSchema(rdd)
-        >>> srdd.collect()
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> df.collect()
         [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')]
         """
-        with SCCallSiteSync(self.context) as css:
-            bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator()
+        with SCCallSiteSync(self._sc) as css:
+            bytesInJava = self._jdf.javaToPython().collect().iterator()
         cls = _create_cls(self.schema())
-        return map(cls, self._collect_iterator_through_file(bytesInJava))
+        tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
+        tempFile.close()
+        self._sc._writeToFile(bytesInJava, tempFile.name)
+        # Read the data into Python and deserialize it:
+        with open(tempFile.name, 'rb') as tempFile:
+            rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
+        os.unlink(tempFile.name)
+        return [cls(r) for r in rs]
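
The collect() above hands the pickled rows from the JVM to Python through a temporary file. A standalone sketch of that local deserialization step using the same serializer classes (illustrative only, no JVM involved; the sample rows are invented)::

    import os
    from tempfile import NamedTemporaryFile
    from pyspark.serializers import BatchedSerializer, PickleSerializer

    ser = BatchedSerializer(PickleSerializer())
    rows = [(1, u'row1'), (2, u'row2'), (3, u'row3')]

    # Python writes the pickled batches here; in collect() the JVM side
    # writes them via _writeToFile before Python reads them back.
    tmp = NamedTemporaryFile(delete=False)
    ser.dump_stream(rows, tmp)
    tmp.close()

    with open(tmp.name, 'rb') as f:
        restored = list(ser.load_stream(f))
    os.unlink(tmp.name)
    assert restored == rows
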
 
     def take(self, num):
         """Take the first num rows of the RDD.
@@ -1974,130 +1990,555 @@ class SchemaRDD(RDD):
         Each object in the list is a Row, the fields can be accessed as
         attributes.
 
-        Unlike the base RDD implementation of take, this implementation
-        leverages the query optimizer to perform a collect on a SchemaRDD,
-        which supports features such as filter pushdown.
-
-        >>> srdd = sqlCtx.inferSchema(rdd)
-        >>> srdd.take(2)
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> df.take(2)
         [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
         """
         return self.limit(num).collect()
 
-    # Convert each object in the RDD to a Row with the right class
-    # for this SchemaRDD, so that fields can be accessed as attributes.
-    def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
+    def map(self, f):
+        """ Return a new RDD by applying a function to each Row, it's a
+        shorthand for df.rdd.map()
         """
-        Return a new RDD by applying a function to each partition of this RDD,
-        while tracking the index of the original partition.
+        return self.rdd.map(f)
 
-        >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
-        >>> def f(splitIndex, iterator): yield splitIndex
-        >>> rdd.mapPartitionsWithIndex(f).sum()
-        6
+    def mapPartitions(self, f, preservesPartitioning=False):
         """
-        rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer)
-
-        schema = self.schema()
+        Return a new RDD by applying a function to each partition.
 
-        def applySchema(_, it):
-            cls = _create_cls(schema)
-            return itertools.imap(cls, it)
-
-        objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning)
-        return objrdd.mapPartitionsWithIndex(f, preservesPartitioning)
+        >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
+        >>> def f(iterator): yield 1
+        >>> rdd.mapPartitions(f).sum()
+        4
+        """
+        return self.rdd.mapPartitions(f, preservesPartitioning)
 
-    # We override the default cache/persist/checkpoint behavior
-    # as we want to cache the underlying SchemaRDD object in the JVM,
-    # not the PythonRDD checkpointed by the super class
     def cache(self):
+        """ Persist with the default storage level (C{MEMORY_ONLY_SER}).
+        """
         self.is_cached = True
-        self._jschema_rdd.cache()
+        self._jdf.cache()
         return self
 
     def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
+        """ Set the storage level to persist its values across operations
+        after the first time it is computed. This can only be used to assign
+        a new storage level if the RDD does not have a storage level set yet.
+        If no storage level is specified defaults to (C{MEMORY_ONLY_SER}).
+        """
         self.is_cached = True
-        javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
-        self._jschema_rdd.persist(javaStorageLevel)
+        javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
+        self._jdf.persist(javaStorageLevel)
         return self
 
     def unpersist(self, blocking=True):
+        """ Mark it as non-persistent, and remove all blocks for it from
+        memory and disk.
+        """
         self.is_cached = False
-        self._jschema_rdd.unpersist(blocking)
+        self._jdf.unpersist(blocking)
         return self
 
-    def checkpoint(self):
-        self.is_checkpointed = True
-        self._jschema_rdd.checkpoint()
+    # def coalesce(self, numPartitions, shuffle=False):
+    #     rdd = self._jdf.coalesce(numPartitions, shuffle, None)
+    #     return DataFrame(rdd, self.sql_ctx)
 
-    def isCheckpointed(self):
-        return self._jschema_rdd.isCheckpointed()
+    def repartition(self, numPartitions):
+        """ Return a new :class:`DataFrame` that has exactly `numPartitions`
+        partitions.
+        """
+        rdd = self._jdf.repartition(numPartitions, None)
+        return DataFrame(rdd, self.sql_ctx)
 
-    def getCheckpointFile(self):
-        checkpointFile = self._jschema_rdd.getCheckpointFile()
-        if checkpointFile.isDefined():
-            return checkpointFile.get()
+    def sample(self, withReplacement, fraction, seed=None):
+        """
+        Return a sampled subset of this DataFrame.
 
-    def coalesce(self, numPartitions, shuffle=False):
-        rdd = self._jschema_rdd.coalesce(numPartitions, shuffle, None)
-        return SchemaRDD(rdd, self.sql_ctx)
+        >>> df = sqlCtx.inferSchema(rdd)
+        >>> df.sample(False, 0.5, 97).count()
+        2L
+        """
+        assert fraction >= 0.0, "Negative fraction value: %s" % fraction
+        seed = seed if seed is not None else random.randint(0, sys.maxint)
+        rdd = self._jdf.sample(withReplacement, fraction, long(seed))
+        return DataFrame(rdd, self.sql_ctx)
+
+    # def takeSample(self, withReplacement, num, seed=None):
+    #     """Return a fixed-size sampled subset of this DataFrame.
+    #
+    #     >>> df = sqlCtx.inferSchema(rdd)
+    #     >>> df.takeSample(False, 2, 97)
+    #     [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
+    #     """
+    #     seed = seed if seed is not None else random.randint(0, sys.maxint)
+    #     with SCCallSiteSync(self.context) as css:
+    #         bytesInJava = self._jdf \
+    #             .takeSampleToPython(withReplacement, num, long(seed)) \
+    #             .iterator()
+    #     cls = _create_cls(self.schema())
+    #     return map(cls, self._collect_iterator_through_file(bytesInJava))
 
-    def distinct(self, numPartitions=None):
-        if numPartitions is None:
-            rdd = self._jschema_rdd.distinct()
+    @property
+    def dtypes(self):
+        """Return all column names and their data types as a list.
+        """
+        return [(f.name, str(f.dataType)) for f in self.schema().fields]
+
+    @property
+    def columns(self):
+        """ Return all column names as a list.
+        """
+        return [f.name for f in self.schema().fields]
+
+    def show(self):
+        raise NotImplementedError
+
+    def join(self, other, joinExprs=None, joinType=None):
+        """
+        Join with another DataFrame, using the given join expression.
+        The following performs a full outer join between `df1` and `df2`::
+
+            df1.join(df2, df1.key == df2.key, "outer")
+
+        :param other: Right side of the join
+        :param joinExprs: Join expression
+        :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`,
+                         `semijoin`.
+        """
+        if joinType is None:
+            if joinExprs is None:
+                jdf = self._jdf.join(other._jdf)
+            else:
+                jdf = self._jdf.join(other._jdf, joinExprs)
         else:
-            rdd = self._jschema_rdd.distinct(numPartitions, None)
-        return SchemaRDD(rdd, self.sql_ctx)
+            jdf = self._jdf.join(other._jdf, joinExprs, joinType)
+        return DataFrame(jdf, self.sql_ctx)
+
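
For reference, the join forms described in the docstring above, exercised end to end (illustrative only; it reuses the `sc`, `sqlCtx` and `Row` fixtures the surrounding doctests assume, and the column names are invented)::

    df1 = sqlCtx.inferSchema(sc.parallelize([Row(key=1, v="a"), Row(key=2, v="b")]))
    df2 = sqlCtx.inferSchema(sc.parallelize([Row(key=1, w="x")]))

    print df1.join(df2).collect()                                    # cartesian product
    print df1.join(df2, df1.key == df2.key).collect()                # join expression only
    print df1.join(df2, df1.key == df2.key, "left_outer").collect()  # explicit join type
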
+    def sort(self, *cols):
+        """ Return a new :class:`DataFrame` sorted by the specified column(s),
+        in ascending order.
+
+        :param cols: The columns or expressions used for sorting
+        """
+        if not cols:
+            raise ValueError("should sort by at least one column")
+        jcols = [_create_column_from_name(c) if isinstance(c, basestring)
+                 else c._jc for c in cols]
+        jdf = self._jdf.sort(*jcols)
+        return DataFrame(jdf, self.sql_ctx)
+
+    sortBy = sort
+
+    def head(self, n=None):
+        """ Return the first `n` rows or the first row if n is None. """
+        if n is None:
+            rs = self.head(1)
+            return rs[0] if rs else None
+        return self.take(n)
+
+    def tail(self):
+        raise NotImplementedError
+
+    def __getitem__(self, item):
+        if isinstance(item, basestring):
+            return Column(self._jdf.apply(item))
+
+        # TODO projection
+        raise IndexError
+
+    def __getattr__(self, name):
+        """ Return the column by given name """
+        if isinstance(name, basestring):
+            return Column(self._jdf.apply(name))
+        raise AttributeError
+
+    def As(self, name):
+        """ Alias the current DataFrame """
+        return DataFrame(getattr(self._jdf, "as")(name), self.sql_ctx)
+
+    def select(self, *cols):
+        """ Selecting a set of expressions.::
+
+            df.select()
+            df.select('colA', 'colB')
+            df.select(df.colA, df.colB + 1)
 
-    def intersection(self, other):
-        if (other.__class__ is SchemaRDD):
-            rdd = self._jschema_rdd.intersection(other._jschema_rdd)
-            return SchemaRDD(rdd, self.sql_ctx)
+        """
+        if not cols:
+            cols = ["*"]
+        if isinstance(cols[0], basestring):
+            cols = [_create_column_from_name(n) for n in cols]
         else:
-            raise ValueError("Can only intersect with another SchemaRDD")
+            cols = [c._jc for c in cols]
+        jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
+        jdf = self._jdf.select(self._jdf.toColumnArray(jcols))
+        return DataFrame(jdf, self.sql_ctx)
 
-    def repartition(self, numPartitions):
-        rdd = self._jschema_rdd.repartition(numPartitions, None)
-        return SchemaRDD(rdd, self.sql_ctx)
+    def filter(self, condition):
+        """ Filtering rows using the given condition::
 
-    def subtract(self, other, numPartitions=None):
-        if (other.__class__ is SchemaRDD):
-            if numPartitions is None:
-                rdd = self._jschema_rdd.subtract(other._jschema_rdd)
-            else:
-                rdd = self._jschema_rdd.subtract(other._jschema_rdd,
-                                                 numPartitions)
-            return SchemaRDD(rdd, self.sql_ctx)
+            df.filter(df.age > 15)
+            df.where(df.age > 15)
+
+        """
+        return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx)
+
+    where = filter
+
+    def groupBy(self, *cols):
+        """ Group the [[DataFrame]] using the specified columns,
+        so we can run aggregation on them. See :class:`GroupedDataFrame`
+        for all the available aggregate functions::
+
+            df.groupBy(df.department).avg()
+            df.groupBy("department", "gender").agg({
+                "salary": "avg",
+                "age":    "max",
+            })
+        """
+        if cols and isinstance(cols[0], basestring):
+            cols = [_create_column_from_name(n) for n in cols]
         else:
-            raise ValueError("Can only subtract another SchemaRDD")
+            cols = [c._jc for c in cols]
+        jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
+        jdf = self._jdf.groupBy(self._jdf.toColumnArray(jcols))
+        return GroupedDataFrame(jdf, self.sql_ctx)
 
-    def sample(self, withReplacement, fraction, seed=None):
+    def agg(self, *exprs):
+        """ Aggregate on the entire [[DataFrame]] without groups
+        (shorthand for df.groupBy.agg())::
+
+            df.agg({"age": "max", "salary": "avg"})
         """
-        Return a sampled subset of this SchemaRDD.
+        return self.groupBy().agg(*exprs)
 
-        >>> srdd = sqlCtx.inferSchema(rdd)
-        >>> srdd.sample(False, 0.5, 97).count()
-        2L
+    def unionAll(self, other):
+        """ Return a new DataFrame containing union of rows in this
+        frame and another frame.
+
+        This is equivalent to `UNION ALL` in SQL.
         """
-        assert fraction >= 0.0, "Negative fraction value: %s" % fraction
-        seed = seed if seed is not None else random.randint(0, sys.maxint)
-        rdd = self._jschema_rdd.sample(withReplacement, fraction, long(seed))
-        return SchemaRDD(rdd, self.sql_ctx)
+        return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx)
 
-    def takeSample(self, withReplacement, num, seed=None):
-        """Return a fixed-size sampled subset of this SchemaRDD.
+    def intersect(self, other):
+        """ Return a new [[DataFrame]] containing rows only in
+        both this frame and another frame.
 
-        >>> srdd = sqlCtx.inferSchema(rdd)
-        >>> srdd.takeSample(False, 2, 97)
-        [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
+        This is equivalent to `INTERSECT` in SQL.
         """
-        seed = seed if seed is not None else random.randint(0, sys.maxint)
-        with SCCallSiteSync(self.context) as css:
-            bytesInJava = self._jschema_rdd.baseSchemaRDD() \
-                .takeSampleToPython(withReplacement, num, long(seed)) \
-                .iterator()
-        cls = _create_cls(self.schema())
-        return map(cls, self._collect_iterator_through_file(bytesInJava))
+        return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
+
+    def Except(self, other):
+        """ Return a new [[DataFrame]] containing rows in this frame
+        but not in another frame.
+
+        This is equivalent to `EXCEPT` in SQL.
+        """
+        return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
+
+    def addColumn(self, colName, col):
+        """ Return a new [[DataFrame]] by adding a column. """
+        return self.select('*', col.As(colName))
+
+    def removeColumn(self, colName):
+        raise NotImplementedError
+
+
+# Having SchemaRDD for backward compatibility (for docs)
+class SchemaRDD(DataFrame):
+    """
+    SchemaRDD is deprecated, please use DataFrame
+    """
+
+
+def dfapi(f):
+    def _api(self):
+        name = f.__name__
+        jdf = getattr(self._jdf, name)()
+        return DataFrame(jdf, self.sql_ctx)
+    _api.__name__ = f.__name__
+    _api.__doc__ = f.__doc__
+    return _api
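
The dfapi decorator above turns an empty, documented method into a call of the same name on the wrapped JVM object, re-wrapped as a DataFrame. A pure-Python sketch of that delegation pattern (illustrative only, no JVM; the Fake* names are invented)::

    def delegating(f):
        def _api(self):
            # Look up and call the method of the same name on the wrapped backend.
            return getattr(self._backend, f.__name__)()
        _api.__name__ = f.__name__
        _api.__doc__ = f.__doc__
        return _api

    class FakeBackend(object):
        def count(self):
            return 42

    class Grouped(object):
        def __init__(self, backend):
            self._backend = backend

        @delegating
        def count(self):
            """Count the number of rows for each group."""

    print Grouped(FakeBackend()).count()   # prints 42, via the backend object
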
+
+
+class GroupedDataFrame(object):
+
+    """
+    A set of methods for aggregations on a :class:`DataFrame`,
+    created by DataFrame.groupBy().
+    """
+
+    def __init__(self, jdf, sql_ctx):
+        self._jdf = jdf
+        self.sql_ctx = sql_ctx
+
+    def agg(self, *exprs):
+        """ Compute aggregates by specifying a map from column name
+        to aggregate methods.
+
+        The available aggregate methods are `avg`, `max`, `min`,
+        `sum`, `count`.
+
+        :param exprs: list of aggregate columns, or a map from column
+                      name to aggregate methods.
+        """
+        if len(exprs) == 1 and isinstance(exprs[0], dict):
+            jmap = MapConverter().convert(exprs[0],
+                                          self.sql_ctx._sc._gateway._gateway_client)
+            jdf = self._jdf.agg(jmap)
+        else:
+            # Columns
+            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Columns"
+            jdf = self._jdf.agg(*[c._jc for c in exprs])
+        return DataFrame(jdf, self.sql_ctx)
+
+    @dfapi
+    def count(self):
+        """ Count the number of rows for each group. """
+
+    @dfapi
+    def mean(self):
+        """Compute the average value for each numeric columns
+        for each group. This is an alias for `avg`."""
+
+    @dfapi
+    def avg(self):
+        """Compute the average value for each numeric columns
+        for each group."""
+
+    @dfapi
+    def max(self):
+        """Compute the max value for each numeric columns for
+        each group. """
+
+    @dfapi
+    def min(self):
+        """Compute the min value for each numeric column for
+        each group."""
+
+    @dfapi
+    def sum(self):
+        """Compute the sum for each numeric columns for each
+        group."""
+
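
Both agg() call forms that GroupedDataFrame supports, for reference (illustrative only; it assumes a DataFrame `df` with the department/gender/salary/age columns used in the groupBy docstring above)::

    grouped = df.groupBy("department", "gender")

    # 1. A map from column name to aggregate method name:
    print grouped.agg({"salary": "avg", "age": "max"}).collect()

    # 2. The generated helpers, which aggregate every numeric column:
    print grouped.count().collect()
    print grouped.avg().collect()
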
+
+SCALA_METHOD_MAPPINGS = {
+    '=': '$eq',
+    '>': '$greater',
+    '<': '$less',
+    '+': '$plus',
+    '-': '$minus',
+    '*': '$times',
+    '/': '$div',
+    '!': '$bang',
+    '@': '$at',
+    '#': '$hash',
+    '%': '$percent',
+    '^': '$up',
+    '&': '$amp',
+    '~': '$tilde',
+    '?': '$qmark',
+    '|': '$bar',
+    '\\': '$bslash',
+    ':': '$colon',
+}
+
+
+def _create_column_from_literal(literal):
+    sc = SparkContext._active_spark_context
+    return sc._jvm.Literal.apply(literal)
+
+
+def _create_column_from_name(name):
+    sc = SparkContext._active_spark_context
+    return sc._jvm.Column(name)
+
+
+def _scalaMethod(name):
+    """ Translate operators into methodName in Scala
+
+    For example:
+    >>> _scalaMethod('+')
+    '$plus'
+    >>> _scalaMethod('>=')
+    '$greater$eq'
+    >>> _scalaMethod('cast')
+    'cast'
+    """
+    return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name)
+
+
+def _unary_op(name):
+    """ Create a method for given unary operator """
+    def _(self):
+        return Column(getattr(self._jc, _scalaMethod(name))(), self._jdf, self.sql_ctx)
+    return _
+
+
+def _bin_op(name):
+    """ Create a method for given binary operator """
+    def _(self, other):
+        if isinstance(other, Column):
+            jc = other._jc
+        else:
+            jc = _create_column_from_literal(other)
+        return Column(getattr(self._jc, _scalaMethod(name))(jc), self._jdf, self.sql_ctx)
+    return _
+
+
+def _reverse_op(name):
+    """ Create a method for binary operator (this object is on right side)
+    """
+    def _(self, other):
+        return Column(getattr(_create_column_from_literal(other), _scalaMethod(name))(self._jc),
+                      self._jdf, self.sql_ctx)
+    return _
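
What the helpers above produce, sketched in pure Python (illustrative only, no JVM; FakeJavaColumn and the local mapping are invented): a Python operator symbol is mangled into the Scala method name and then invoked on the wrapped JVM column object::

    class FakeJavaColumn(object):
        def __getattr__(self, name):
            return lambda other: "invoked %s(%r)" % (name, other)

    _MAP = {'>': '$greater', '=': '$eq'}             # subset of SCALA_METHOD_MAPPINGS
    mangled = ''.join(_MAP.get(c, c) for c in ">=")  # '$greater$eq'
    print getattr(FakeJavaColumn(), mangled)(10)     # invoked $greater$eq(10)
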
+
+
+class Column(DataFrame):
+
+    """
+    A column in a DataFrame.
+
+    `Column` instances can be created by::
+
+        # 1. Select a column out of a DataFrame
+        df.colName
+        df["colName"]
+
+        # 2. Create from an expression
+        df["colName"] + 1
+    """
+
+    def __init__(self, jc, jdf=None, sql_ctx=None):
+        self._jc = jc
+        super(Column, self).__init__(jdf, sql_ctx)
+
+    # arithmetic operators
+    __neg__ = _unary_op("unary_-")
+    __add__ = _bin_op("+")
+    __sub__ = _bin_op("-")
+    __mul__ = _bin_op("*")
+    __div__ = _bin_op("/")
+    __mod__ = _bin_op("%")
+    __radd__ = _bin_op("+")
+    __rsub__ = _reverse_op("-")
+    __rmul__ = _bin_op("*")
+    __rdiv__ = _reverse_op("/")
+    __rmod__ = _reverse_op("%")
+    __abs__ = _unary_op("abs")
+    abs = _unary_op("abs")
+    sqrt = _unary_op("sqrt")
+
+    # comparison and logical operators
+    __eq__ = _bin_op("===")
+    __ne__ = _bin_op("!==")
+    __lt__ = _bin_op("<")
+    __le__ = _bin_op("<=")
+    __ge__ = _bin_op(">=")
+    __gt__ = _bin_op(">")
+    # `and`, `or`, `not` cannot be overloaded in Python
+    And = _bin_op('&&')
+    Or = _bin_op('||')
+    Not = _unary_op('unary_!')
+
+    # bitwise operators
+    __and__ = _bin_op("&")
+    __or__ = _bin_op("|")
+    __invert__ = _unary_op("unary_~")
+    __xor__ = _bin_op("^")
+    # __lshift__ = _bin_op("<<")
+    # __rshift__ = _bin_op(">>")
+    __rand__ = _bin_op("&")
+    __ror__ = _bin_op("|")
+    __rxor__ = _bin_op("^")
+    # __rlshift__ = _reverse_op("<<")
+    # __rrshift__ = _reverse_op(">>")
+
+    # container operators
+    __contains__ = _bin_op("contains")
+    __getitem__ = _bin_op("getItem")
+    # __getattr__ = _bin_op("getField")
+
+    # string methods
+    rlike = _bin_op("rlike")
+    like = _bin_op("like")
+    startswith = _bin_op("startsWith")
+    endswith = _bin_op("endsWith")
+    upper = _unary_op("upper")
+    lower = _unary_op("lower")
+
+    def substr(self, startPos, pos):
+        if type(startPos) != type(pos):
+            raise TypeError("Cannot mix the types of startPos and pos")
+        if isinstance(startPos, (int, long)):
+            jc = self._jc.substr(startPos, pos)
+        elif isinstance(startPos, Column):
+            jc = self._jc.substr(startPos._jc, pos._jc)
+        else:
+            raise TypeError("Unexpected type: %s" % type(startPos))
+        return Column(jc, self._jdf, self.sql_ctx)
+
+    __getslice__ = substr
+
+    # order
+    asc = _unary_op("asc")
+    desc = _unary_op("desc")
+
+    isNull = _unary_op("isNull")
+    isNotNull = _unary_op("isNotNull")
+
+    # `as` is a keyword
+    def As(self, alias):
+        return Column(getattr(self._jc, "as")(alias), self._jdf, self.sql_ctx)
+
+    def cast(self, dataType):
+        if self.sql_ctx is None:
+            sc = SparkContext._active_spark_context
+            ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+        else:
+            ssql_ctx = self.sql_ctx._ssql_ctx
+        jdt = ssql_ctx.parseDataType(dataType.json())
+        return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx)
+
+
+def _aggregate_func(name):
+    """ Creat a function for aggregator by name"""
+    def _(col):
+        sc = SparkContext._active_spark_context
+        if isinstance(col, Column):
+            jcol = col._jc
+        else:
+            jcol = _create_column_from_name(col)
+        # FIXME: can not access dsl.min/max ...
+        jc = getattr(sc._jvm.org.apache.spark.sql.dsl(), name)(jcol)
+        return Column(jc)
+    return staticmethod(_)
+
+
+class Aggregator(object):
+    """
+    A collection of builtin aggregators
+    """
+    max = _aggregate_func("max")
+    min = _aggregate_func("min")
+    avg = mean = _aggregate_func("mean")
+    sum = _aggregate_func("sum")
+    first = _aggregate_func("first")
+    last = _aggregate_func("last")
+    count = _aggregate_func("count")
 
 
 def _test():

http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index b474fcf..e8e207a 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -806,6 +806,9 @@ class SQLTests(ReusedPySparkTestCase):
 
     def setUp(self):
         self.sqlCtx = SQLContext(self.sc)
+        self.testData = [Row(key=i, value=str(i)) for i in range(100)]
+        rdd = self.sc.parallelize(self.testData)
+        self.df = self.sqlCtx.inferSchema(rdd)
 
     def test_udf(self):
         self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
@@ -821,7 +824,7 @@ class SQLTests(ReusedPySparkTestCase):
     def test_udf_with_array_type(self):
         d = [Row(l=range(3), d={"key": range(5)})]
         rdd = self.sc.parallelize(d)
-        srdd = self.sqlCtx.inferSchema(rdd).registerTempTable("test")
+        self.sqlCtx.inferSchema(rdd).registerTempTable("test")
         self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
         self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
         [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
@@ -839,68 +842,51 @@ class SQLTests(ReusedPySparkTestCase):
 
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
-        srdd = self.sqlCtx.jsonRDD(rdd)
-        srdd.count()
-        srdd.collect()
-        srdd.schemaString()
-        srdd.schema()
+        df = self.sqlCtx.jsonRDD(rdd)
+        df.count()
+        df.collect()
+        df.schema()
 
         # cache and checkpoint
-        self.assertFalse(srdd.is_cached)
-        srdd.persist()
-        srdd.unpersist()
-        srdd.cache()
-        self.assertTrue(srdd.is_cached)
-        self.assertFalse(srdd.isCheckpointed())
-        self.assertEqual(None, srdd.getCheckpointFile())
-
-        srdd = srdd.coalesce(2, True)
-        srdd = srdd.repartition(3)
-        srdd = srdd.distinct()
-        srdd.intersection(srdd)
-        self.assertEqual(2, srdd.count())
-
-        srdd.registerTempTable("temp")
-        srdd = self.sqlCtx.sql("select foo from temp")
-        srdd.count()
-        srdd.collect()
-
-    def test_distinct(self):
-        rdd = self.sc.parallelize(['{"a": 1}', '{"b": 2}', '{"c": 3}']*10, 10)
-        srdd = self.sqlCtx.jsonRDD(rdd)
-        self.assertEquals(srdd.getNumPartitions(), 10)
-        self.assertEquals(srdd.distinct().count(), 3)
-        result = srdd.distinct(5)
-        self.assertEquals(result.getNumPartitions(), 5)
-        self.assertEquals(result.count(), 3)
+        self.assertFalse(df.is_cached)
+        df.persist()
+        df.unpersist()
+        df.cache()
+        self.assertTrue(df.is_cached)
+        self.assertEqual(2, df.count())
+
+        df.registerTempTable("temp")
+        df = self.sqlCtx.sql("select foo from temp")
+        df.count()
+        df.collect()
 
     def test_apply_schema_to_row(self):
-        srdd = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
-        srdd2 = self.sqlCtx.applySchema(srdd.map(lambda x: x), srdd.schema())
-        self.assertEqual(srdd.collect(), srdd2.collect())
+        df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
+        df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema())
+        self.assertEqual(df.collect(), df2.collect())
 
         rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
-        srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema())
-        self.assertEqual(10, srdd3.count())
+        df3 = self.sqlCtx.applySchema(rdd, df.schema())
+        self.assertEqual(10, df3.count())
 
     def test_serialize_nested_array_and_map(self):
         d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
         rdd = self.sc.parallelize(d)
-        srdd = self.sqlCtx.inferSchema(rdd)
-        row = srdd.first()
+        df = self.sqlCtx.inferSchema(rdd)
+        row = df.head()
         self.assertEqual(1, len(row.l))
         self.assertEqual(1, row.l[0].a)
         self.assertEqual("2", row.d["key"].d)
 
-        l = srdd.map(lambda x: x.l).first()
+        l = df.map(lambda x: x.l).first()
         self.assertEqual(1, len(l))
         self.assertEqual('s', l[0].b)
 
-        d = srdd.map(lambda x: x.d).first()
+        d = df.map(lambda x: x.d).first()
         self.assertEqual(1, len(d))
         self.assertEqual(1.0, d["key"].c)
 
-        row = srdd.map(lambda x: x.d["key"]).first()
+        row = df.map(lambda x: x.d["key"]).first()
         self.assertEqual(1.0, row.c)
         self.assertEqual("2", row.d)
 
@@ -908,26 +894,26 @@ class SQLTests(ReusedPySparkTestCase):
         d = [Row(l=[], d={}),
              Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
         rdd = self.sc.parallelize(d)
-        srdd = self.sqlCtx.inferSchema(rdd)
-        self.assertEqual([], srdd.map(lambda r: r.l).first())
-        self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect())
-        srdd.registerTempTable("test")
+        df = self.sqlCtx.inferSchema(rdd)
+        self.assertEqual([], df.map(lambda r: r.l).first())
+        self.assertEqual([None, ""], df.map(lambda r: r.s).collect())
+        df.registerTempTable("test")
         result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
-        self.assertEqual(1, result.first()[0])
+        self.assertEqual(1, result.head()[0])
 
-        srdd2 = self.sqlCtx.inferSchema(rdd, 1.0)
-        self.assertEqual(srdd.schema(), srdd2.schema())
-        self.assertEqual({}, srdd2.map(lambda r: r.d).first())
-        self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect())
-        srdd2.registerTempTable("test2")
+        df2 = self.sqlCtx.inferSchema(rdd, 1.0)
+        self.assertEqual(df.schema(), df2.schema())
+        self.assertEqual({}, df2.map(lambda r: r.d).first())
+        self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
+        df2.registerTempTable("test2")
         result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
-        self.assertEqual(1, result.first()[0])
+        self.assertEqual(1, result.head()[0])
 
     def test_struct_in_map(self):
         d = [Row(m={Row(i=1): Row(s="")})]
         rdd = self.sc.parallelize(d)
-        srdd = self.sqlCtx.inferSchema(rdd)
-        k, v = srdd.first().m.items()[0]
+        df = self.sqlCtx.inferSchema(rdd)
+        k, v = df.head().m.items()[0]
         self.assertEqual(1, k.i)
         self.assertEqual("", v.s)
 
@@ -935,9 +921,9 @@ class SQLTests(ReusedPySparkTestCase):
         row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
         self.assertEqual(1, row.asDict()['l'][0].a)
         rdd = self.sc.parallelize([row])
-        srdd = self.sqlCtx.inferSchema(rdd)
-        srdd.registerTempTable("test")
-        row = self.sqlCtx.sql("select l, d from test").first()
+        df = self.sqlCtx.inferSchema(rdd)
+        df.registerTempTable("test")
+        row = self.sqlCtx.sql("select l, d from test").head()
         self.assertEqual(1, row.asDict()["l"][0].a)
         self.assertEqual(1.0, row.asDict()['d']['key'].c)
 
@@ -945,12 +931,12 @@ class SQLTests(ReusedPySparkTestCase):
         from pyspark.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
         rdd = self.sc.parallelize([row])
-        srdd = self.sqlCtx.inferSchema(rdd)
-        schema = srdd.schema()
+        df = self.sqlCtx.inferSchema(rdd)
+        schema = df.schema()
         field = [f for f in schema.fields if f.name == "point"][0]
         self.assertEqual(type(field.dataType), ExamplePointUDT)
-        srdd.registerTempTable("labeled_point")
-        point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point
+        df.registerTempTable("labeled_point")
+        point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
         self.assertEqual(point, ExamplePoint(1.0, 2.0))
 
     def test_apply_schema_with_udt(self):
@@ -959,21 +945,52 @@ class SQLTests(ReusedPySparkTestCase):
         rdd = self.sc.parallelize([row])
         schema = StructType([StructField("label", DoubleType(), False),
                              StructField("point", ExamplePointUDT(), False)])
-        srdd = self.sqlCtx.applySchema(rdd, schema)
-        point = srdd.first().point
+        df = self.sqlCtx.applySchema(rdd, schema)
+        point = df.head().point
         self.assertEquals(point, ExamplePoint(1.0, 2.0))
 
     def test_parquet_with_udt(self):
         from pyspark.tests import ExamplePoint
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
         rdd = self.sc.parallelize([row])
-        srdd0 = self.sqlCtx.inferSchema(rdd)
+        df0 = self.sqlCtx.inferSchema(rdd)
         output_dir = os.path.join(self.tempdir.name, "labeled_point")
-        srdd0.saveAsParquetFile(output_dir)
-        srdd1 = self.sqlCtx.parquetFile(output_dir)
-        point = srdd1.first().point
+        df0.saveAsParquetFile(output_dir)
+        df1 = self.sqlCtx.parquetFile(output_dir)
+        point = df1.head().point
         self.assertEquals(point, ExamplePoint(1.0, 2.0))
 
+    def test_column_operators(self):
+        from pyspark.sql import Column, LongType
+        ci = self.df.key
+        cs = self.df.value
+        c = ci == cs
+        self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
+        rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci)
+        self.assertTrue(all(isinstance(c, Column) for c in rcc))
+        cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs]
+        self.assertTrue(all(isinstance(c, Column) for c in cb))
+        cbit = (ci & ci), (ci | ci), (ci ^ ci), (~ci)
+        self.assertTrue(all(isinstance(c, Column) for c in cbit))
+        css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a')
+        self.assertTrue(all(isinstance(c, Column) for c in css))
+        self.assertTrue(isinstance(ci.cast(LongType()), Column))
+
+    def test_column_select(self):
+        df = self.df
+        self.assertEqual(self.testData, df.select("*").collect())
+        self.assertEqual(self.testData, df.select(df.key, df.value).collect())
+        self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
+
+    def test_aggregator(self):
+        df = self.df
+        g = df.groupBy()
+        self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
+        self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
+        # TODO(davies): fix aggregators
+        from pyspark.sql import Aggregator as Agg
+        # self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
+
 
 class InputFormatTests(ReusedPySparkTestCase):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala
index 22941ed..4c5fb3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala
@@ -47,7 +47,7 @@ object NewRelationInstances extends Rule[LogicalPlan] {
       .toSet
 
     plan transform {
-      case l: MultiInstanceRelation if multiAppearance contains l => l.newInstance
+      case l: MultiInstanceRelation if multiAppearance.contains(l) => l.newInstance()
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 3035d93..f388cd5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -77,6 +77,9 @@ abstract class Attribute extends NamedExpression {
  * For example the SQL expression "1 + 1 AS a" could be represented as follows:
 *  Alias(Add(Literal(1), Literal(1)), "a")()
  *
+ * Note that exprId and qualifiers are in a separate parameter list because
+ * we only pattern match on child and name.
+ *
  * @param child the computation being performed
  * @param name the name to be associated with the result of computing [[child]].
  * @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this
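
The note added above can be illustrated with a minimal, standalone Scala sketch (not Spark code): when a case class has a second parameter list, only the first list participates in the compiler-generated extractor, so a pattern match sees just those fields.

    // Hypothetical case class mirroring the Alias layout: the second
    // parameter list (id) is invisible to pattern matching.
    case class Tagged(child: String, name: String)(val id: Long)

    val t = Tagged("1 + 1", "a")(42L)
    t match {
      // Only child and name are bound here; id never appears in the pattern.
      case Tagged(child, name) => println(s"$child AS $name")
    }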

http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index 613f4bb..5dc0539 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -17,9 +17,24 @@
 
 package org.apache.spark.sql.catalyst.plans
 
+object JoinType {
+  def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match {
+    case "inner" => Inner
+    case "outer" | "full" | "fullouter" => FullOuter
+    case "leftouter" | "left" => LeftOuter
+    case "rightouter" | "right" => RightOuter
+    case "leftsemi" => LeftSemi
+  }
+}
+
 sealed abstract class JoinType
+
 case object Inner extends JoinType
+
 case object LeftOuter extends JoinType
+
 case object RightOuter extends JoinType
+
 case object FullOuter extends JoinType
+
 case object LeftSemi extends JoinType
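
For illustration, a short sketch of how the new JoinType factory above resolves user-supplied strings: the input is lower-cased and underscores are stripped before matching, so common spellings map to the same join type.

    import org.apache.spark.sql.catalyst.plans._

    assert(JoinType("LEFT_OUTER") == LeftOuter)
    assert(JoinType("left") == LeftOuter)
    assert(JoinType("fullouter") == FullOuter)
    assert(JoinType("leftsemi") == LeftSemi)
    // Unrecognized strings (e.g. "cross") fall through and throw a MatchError.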

http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
index 1976998..d90af45 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala
@@ -19,10 +19,14 @@ package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.types.{StructType, StructField}
 
 object LocalRelation {
-  def apply(output: Attribute*) =
-    new LocalRelation(output)
+  def apply(output: Attribute*): LocalRelation = new LocalRelation(output)
+
+  def apply(output1: StructField, output: StructField*): LocalRelation = new LocalRelation(
+    StructType(output1 +: output).toAttributes
+  )
 }
 
 case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil)
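
A small usage sketch of the new StructField-based factory above (assuming the catalyst and types packages shown in the imports are on the classpath): the fields are turned into attributes via StructType.toAttributes.

    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
    import org.apache.spark.sql.types._

    // Builds a relation whose output attributes are id: int and name: string.
    val rel = LocalRelation(
      StructField("id", IntegerType),
      StructField("name", StringType))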

http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
index e715d94..bc22f68 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
@@ -80,7 +80,7 @@ private[sql] trait CacheManager {
    * the in-memory columnar representation of the underlying table is expensive.
    */
   private[sql] def cacheQuery(
-      query: SchemaRDD,
+      query: DataFrame,
       tableName: Option[String] = None,
       storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
     val planToCache = query.queryExecution.analyzed
@@ -100,7 +100,7 @@ private[sql] trait CacheManager {
   }
 
   /** Removes the data for the given DataFrame from the cache */
-  private[sql] def uncacheQuery(query: SchemaRDD, blocking: Boolean = true): Unit = writeLock {
+  private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock {
     val planToCache = query.queryExecution.analyzed
     val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
     require(dataIndex >= 0, s"Table $query is not cached.")
@@ -110,7 +110,7 @@ private[sql] trait CacheManager {
 
   /** Tries to remove the data for the given DataFrame from the cache if it's cached */
   private[sql] def tryUncacheQuery(
-      query: SchemaRDD,
+      query: DataFrame,
       blocking: Boolean = true): Boolean = writeLock {
     val planToCache = query.queryExecution.analyzed
     val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
@@ -123,7 +123,7 @@ private[sql] trait CacheManager {
   }
 
   /** Optionally returns cached data for the given DataFrame */
-  private[sql] def lookupCachedData(query: SchemaRDD): Option[CachedData] = readLock {
+  private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock {
     lookupCachedData(query.queryExecution.analyzed)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/119f45d6/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
new file mode 100644
index 0000000..7fc8347
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -0,0 +1,528 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+import scala.language.implicitConversions
+
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, Star}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
+import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
+import org.apache.spark.sql.types._
+
+
+object Column {
+  def unapply(col: Column): Option[Expression] = Some(col.expr)
+
+  def apply(colName: String): Column = new Column(colName)
+}
+
+
+/**
+ * A column in a [[DataFrame]].
+ *
+ * `Column` instances can be created by:
+ * {{{
+ *   // 1. Select a column out of a DataFrame
+ *   df("colName")
+ *
+ *   // 2. Create a literal expression
+ *   Literal(1)
+ *
+ *   // 3. Create new columns from existing ones
+ * }}}
+ *
+ */
+// TODO: Improve documentation.
+class Column(
+    sqlContext: Option[SQLContext],
+    plan: Option[LogicalPlan],
+    val expr: Expression)
+  extends DataFrame(sqlContext, plan) with ExpressionApi {
+
+  /** Turn a Catalyst expression into a `Column`. */
+  protected[sql] def this(expr: Expression) = this(None, None, expr)
+
+  /**
+   * Create a new `Column` expression based on a column or attribute name.
+   * The resolution of this is the same as SQL. For example:
+   *
+   * - "colName" becomes an expression selecting the column named "colName".
+   * - "*" becomes an expression selecting all columns.
+   * - "df.*" becomes an expression selecting all columns in data frame "df".
+   */
+  def this(name: String) = this(name match {
+    case "*" => Star(None)
+    case _ if name.endsWith(".*") => Star(Some(name.substring(0, name.length - 2)))
+    case _ => UnresolvedAttribute(name)
+  })
+
+  override def isComputable: Boolean = sqlContext.isDefined && plan.isDefined
+
+  /**
+   * An implicit conversion function internal to this class. This function creates a new Column
+   * based on an expression. If the expression itself is not named, it aliases the expression
+   * by calling it "col".
+   */
+  private[this] implicit def toColumn(expr: Expression): Column = {
+    val projectedPlan = plan.map { p =>
+      Project(Seq(expr match {
+        case named: NamedExpression => named
+        case unnamed: Expression => Alias(unnamed, "col")()
+      }), p)
+    }
+    new Column(sqlContext, projectedPlan, expr)
+  }
+
+  /**
+   * Unary minus, i.e. negate the expression.
+   * {{{
+   *   // Select the amount column and negate all values.
+   *   df.select( -df("amount") )
+   * }}}
+   */
+  override def unary_- : Column = UnaryMinus(expr)
+
+  /**
+   * Bitwise NOT.
+   * {{{
+   *   // Select the flags column and negate every bit.
+   *   df.select( ~df("flags") )
+   * }}}
+   */
+  override def unary_~ : Column = BitwiseNot(expr)
+
+  /**
+   * Invert a boolean expression, i.e. NOT.
+   * {{{
+   *   // Select rows that are not active (isActive === false)
+   *   df.select( !df("isActive") )
+   * }}}
+   */
+  override def unary_! : Column = Not(expr)
+
+
+  /**
+   * Equality test with an expression.
+   * {{{
+   *   // The following two both select rows in which colA equals colB.
+   *   df.select( df("colA") === df("colB") )
+   *   df.select( df("colA".equalTo(df("colB")) )
+   * }}}
+   */
+  override def === (other: Column): Column = EqualTo(expr, other.expr)
+
+  /**
+   * Equality test with a literal value.
+   * {{{
+   *   // The following two both select rows in which colA is "Zaharia".
+   *   df.select( df("colA") === "Zaharia")
+   *   df.select( df("colA".equalTo("Zaharia") )
+   * }}}
+   */
+  override def === (literal: Any): Column = this === Literal.anyToLiteral(literal)
+
+  /**
+   * Equality test with an expression.
+   * {{{
+   *   // The following two both select rows in which colA equals colB.
+   *   df.select( df("colA") === df("colB") )
+   *   df.select( df("colA".equalTo(df("colB")) )
+   * }}}
+   */
+  override def equalTo(other: Column): Column = this === other
+
+  /**
+   * Equality test with a literal value.
+   * {{{
+   *   // The following two both select rows in which colA is "Zaharia".
+   *   df.select( df("colA") === "Zaharia")
+   *   df.select( df("colA".equalTo("Zaharia") )
+   * }}}
+   */
+  override def equalTo(literal: Any): Column = this === literal
+
+  /**
+   * Inequality test with an expression.
+   * {{{
+   *   // The following two both select rows in which colA does not equal colB.
+   *   df.select( df("colA") !== df("colB") )
+   *   df.select( !(df("colA") === df("colB")) )
+   * }}}
+   */
+  override def !== (other: Column): Column = Not(EqualTo(expr, other.expr))
+
+  /**
+   * Inequality test with a literal value.
+   * {{{
+   *   // The following two both select rows in which colA does not equal 15.
+   *   df.select( df("colA") !== 15 )
+   *   df.select( !(df("colA") === 15) )
+   * }}}
+   */
+  override def !== (literal: Any): Column = this !== Literal.anyToLiteral(literal)
+
+  /**
+   * Greater than an expression.
+   * {{{
+   *   // The following selects people older than 21.
+   *   people.select( people("age") > Literal(21) )
+   * }}}
+   */
+  override def > (other: Column): Column = GreaterThan(expr, other.expr)
+
+  /**
+   * Greater than a literal value.
+   * {{{
+   *   // The following selects people older than 21.
+   *   people.select( people("age") > 21 )
+   * }}}
+   */
+  override def > (literal: Any): Column = this > Literal.anyToLiteral(literal)
+
+  /**
+   * Less than an expression.
+   * {{{
+   *   // The following selects people younger than 21.
+   *   people.select( people("age") < Literal(21) )
+   * }}}
+   */
+  override def < (other: Column): Column = LessThan(expr, other.expr)
+
+  /**
+   * Less than a literal value.
+   * {{{
+   *   // The following selects people younger than 21.
+   *   people.select( people("age") < 21 )
+   * }}}
+   */
+  override def < (literal: Any): Column = this < Literal.anyToLiteral(literal)
+
+  /**
+   * Less than or equal to an expression.
+   * {{{
+   *   // The following selects people aged 21 or younger.
+   *   people.select( people("age") <= Literal(21) )
+   * }}}
+   */
+  override def <= (other: Column): Column = LessThanOrEqual(expr, other.expr)
+
+  /**
+   * Less than or equal to a literal value.
+   * {{{
+   *   // The following selects people aged 21 or younger.
+   *   people.select( people("age") <= 21 )
+   * }}}
+   */
+  override def <= (literal: Any): Column = this <= Literal.anyToLiteral(literal)
+
+  /**
+   * Greater than or equal to an expression.
+   * {{{
+   *   // The following selects people aged 21 or older.
+   *   people.select( people("age") >= Literal(21) )
+   * }}}
+   */
+  override def >= (other: Column): Column = GreaterThanOrEqual(expr, other.expr)
+
+  /**
+   * Greater than or equal to a literal value.
+   * {{{
+   *   // The following selects people aged 21 or older.
+   *   people.select( people("age") >= 21 )
+   * }}}
+   */
+  override def >= (literal: Any): Column = this >= Literal.anyToLiteral(literal)
+
+  /**
+   * Equality test with an expression that is safe for null values.
+   */
+  override def <=> (other: Column): Column = EqualNullSafe(expr, other.expr)
+
+  /**
+   * Equality test with a literal value that is safe for null values.
+   */
+  override def <=> (literal: Any): Column = this <=> Literal.anyToLiteral(literal)
+
+  /**
+   * True if the current expression is null.
+   */
+  override def isNull: Column = IsNull(expr)
+
+  /**
+   * True if the current expression is NOT null.
+   */
+  override def isNotNull: Column = IsNotNull(expr)
+
+  /**
+   * Boolean OR with an expression.
+   * {{{
+   *   // The following selects people that are in school or employed.
+   *   people.select( people("inSchool") || people("isEmployed") )
+   * }}}
+   */
+  override def || (other: Column): Column = Or(expr, other.expr)
+
+  /**
+   * Boolean OR with a literal value.
+   * {{{
+   *   // The following selects everything.
+   *   people.select( people("inSchool") || true )
+   * }}}
+   */
+  override def || (literal: Boolean): Column = this || Literal.anyToLiteral(literal)
+
+  /**
+   * Boolean AND with an expression.
+   * {{{
+   *   // The following selects people that are in school and employed at the same time.
+   *   people.select( people("inSchool") && people("isEmployed") )
+   * }}}
+   */
+  override def && (other: Column): Column = And(expr, other.expr)
+
+  /**
+   * Boolean AND with a literal value.
+   * {{{
+   *   // The following selects people that are in school.
+   *   people.select( people("inSchool") && true )
+   * }}}
+   */
+  override def && (literal: Boolean): Column = this && Literal.anyToLiteral(literal)
+
+  /**
+   * Bitwise AND with an expression.
+   */
+  override def & (other: Column): Column = BitwiseAnd(expr, other.expr)
+
+  /**
+   * Bitwise AND with a literal value.
+   */
+  override def & (literal: Any): Column = this & Literal.anyToLiteral(literal)
+
+  /**
+   * Bitwise OR with an expression.
+   */
+  override def | (other: Column): Column = BitwiseOr(expr, other.expr)
+
+  /**
+   * Bitwise OR with a literal value.
+   */
+  override def | (literal: Any): Column = this | Literal.anyToLiteral(literal)
+
+  /**
+   * Bitwise XOR with an expression.
+   */
+  override def ^ (other: Column): Column = BitwiseXor(expr, other.expr)
+
+  /**
+   * Bitwise XOR with a literal value.
+   */
+  override def ^ (literal: Any): Column = this ^ Literal.anyToLiteral(literal)
+
+  /**
+   * Sum of this expression and another expression.
+   * {{{
+   *   // The following selects the sum of a person's height and weight.
+   *   people.select( people("height") + people("weight") )
+   * }}}
+   */
+  override def + (other: Column): Column = Add(expr, other.expr)
+
+  /**
+   * Sum of this expression and another expression.
+   * {{{
+   *   // The following selects the sum of a person's height and 10.
+   *   people.select( people("height") + 10 )
+   * }}}
+   */
+  override def + (literal: Any): Column = this + Literal.anyToLiteral(literal)
+
+  /**
+   * Subtraction. Subtract the other expression from this expression.
+   * {{{
+   *   // The following selects the difference between people's height and their weight.
+   *   people.select( people("height") - people("weight") )
+   * }}}
+   */
+  override def - (other: Column): Column = Subtract(expr, other.expr)
+
+  /**
+   * Subtraction. Subtract a literal value from this expression.
+   * {{{
+   *   // The following subtracts 10 from a person's height.
+   *   people.select( people("height") - 10 )
+   * }}}
+   */
+  override def - (literal: Any): Column = this - Literal.anyToLiteral(literal)
+
+  /**
+   * Multiply this expression and another expression.
+   * {{{
+   *   // The following multiplies a person's height by their weight.
+   *   people.select( people("height") * people("weight") )
+   * }}}
+   */
+  override def * (other: Column): Column = Multiply(expr, other.expr)
+
+  /**
+   * Multiply this expression and a literal value.
+   * {{{
+   *   // The following multiplies a person's height by 10.
+   *   people.select( people("height") * 10 )
+   * }}}
+   */
+  override def * (literal: Any): Column = this * Literal.anyToLiteral(literal)
+
+  /**
+   * Divide this expression by another expression.
+   * {{{
+   *   // The following divides a person's height by their weight.
+   *   people.select( people("height") / people("weight") )
+   * }}}
+   */
+  override def / (other: Column): Column = Divide(expr, other.expr)
+
+  /**
+   * Divide this expression by a literal value.
+   * {{{
+   *   // The following divides a person's height by 10.
+   *   people.select( people("height") / 10 )
+   * }}}
+   */
+  override def / (literal: Any): Column = this / Literal.anyToLiteral(literal)
+
+  /**
+   * Modulo (a.k.a. remainder) expression.
+   */
+  override def % (other: Column): Column = Remainder(expr, other.expr)
+
+  /**
+   * Modulo (a.k.a. remainder) expression.
+   */
+  override def % (literal: Any): Column = this % Literal.anyToLiteral(literal)
+
+
+  /**
+   * A boolean expression that is evaluated to true if the value of this expression is contained
+   * in the evaluated values of the arguments.
+   */
+  @scala.annotation.varargs
+  override def in(list: Column*): Column = In(expr, list.map(_.expr))
+
+  override def like(other: Column): Column = Like(expr, other.expr)
+
+  override def like(literal: String): Column = this.like(Literal.anyToLiteral(literal))
+
+  override def rlike(other: Column): Column = RLike(expr, other.expr)
+
+  override def rlike(literal: String): Column = this.rlike(Literal.anyToLiteral(literal))
+
+
+  override def getItem(ordinal: Int): Column = GetItem(expr, LiteralExpr(ordinal))
+
+  override def getItem(ordinal: Column): Column = GetItem(expr, ordinal.expr)
+
+  override def getField(fieldName: String): Column = GetField(expr, fieldName)
+
+
+  override def substr(startPos: Column, len: Column): Column =
+    Substring(expr, startPos.expr, len.expr)
+
+  override def substr(startPos: Int, len: Int): Column =
+    this.substr(Literal.anyToLiteral(startPos), Literal.anyToLiteral(len))
+
+  override def contains(other: Column): Column = Contains(expr, other.expr)
+
+  override def contains(literal: Any): Column = this.contains(Literal.anyToLiteral(literal))
+
+
+  override def startsWith(other: Column): Column = StartsWith(expr, other.expr)
+
+  override def startsWith(literal: String): Column = this.startsWith(Literal.anyToLiteral(literal))
+
+  override def endsWith(other: Column): Column = EndsWith(expr, other.expr)
+
+  override def endsWith(literal: String): Column = this.endsWith(Literal.anyToLiteral(literal))
+
+  override def as(alias: String): Column = Alias(expr, alias)()
+
+  override def cast(to: DataType): Column = Cast(expr, to)
+
+  override def desc: Column = SortOrder(expr, Descending)
+
+  override def asc: Column = SortOrder(expr, Ascending)
+}
+
+
+class ColumnName(name: String) extends Column(name) {
+
+  /** Creates a new StructField of type boolean */
+  def boolean: StructField = StructField(name, BooleanType)
+
+  /** Creates a new StructField of type byte */
+  def byte: StructField = StructField(name, ByteType)
+
+  /** Creates a new StructField of type short */
+  def short: StructField = StructField(name, ShortType)
+
+  /** Creates a new StructField of type int */
+  def int: StructField = StructField(name, IntegerType)
+
+  /** Creates a new StructField of type long */
+  def long: StructField = StructField(name, LongType)
+
+  /** Creates a new StructField of type float */
+  def float: StructField = StructField(name, FloatType)
+
+  /** Creates a new StructField of type double */
+  def double: StructField = StructField(name, DoubleType)
+
+  /** Creates a new StructField of type string */
+  def string: StructField = StructField(name, StringType)
+
+  /** Creates a new StructField of type date */
+  def date: StructField = StructField(name, DateType)
+
+  /** Creates a new StructField of type decimal */
+  def decimal: StructField = StructField(name, DecimalType.Unlimited)
+
+  /** Creates a new StructField of type decimal */
+  def decimal(precision: Int, scale: Int): StructField =
+    StructField(name, DecimalType(precision, scale))
+
+  /** Creates a new StructField of type timestamp */
+  def timestamp: StructField = StructField(name, TimestampType)
+
+  /** Creates a new StructField of type binary */
+  def binary: StructField = StructField(name, BinaryType)
+
+  /** Creates a new StructField of type array */
+  def array(dataType: DataType): StructField = StructField(name, ArrayType(dataType))
+
+  /** Creates a new StructField of type map */
+  def map(keyType: DataType, valueType: DataType): StructField =
+    map(MapType(keyType, valueType))
+
+  def map(mapType: MapType): StructField = StructField(name, mapType)
+
+  /** Creates a new StructField of type struct */
+  def struct(fields: StructField*): StructField = struct(StructType(fields))
+
+  def struct(structType: StructType): StructField = StructField(name, structType)
+}
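
To tie the pieces of this new file together, a hedged sketch of the Column operators and the ColumnName helpers (the DataFrame `people` and its columns are assumptions; df("colName") returning a Column is taken from the class comment above):

    import org.apache.spark.sql.{Column, ColumnName}
    import org.apache.spark.sql.types._

    // Comparison and boolean operators accept another Column or a plain literal
    // and build new, unnamed columns; `as` gives the result a name.
    val isAdult: Column = people("age") >= 21 && people("isEmployed")
    val nextAge: Column = (people("age") + 1).as("nextAge")
    // e.g. people.select(nextAge), as in the scaladoc examples above.

    // ColumnName doubles as a small schema DSL that produces StructFields.
    val schema = StructType(Seq(
      new ColumnName("name").string,
      new ColumnName("age").int,
      new ColumnName("salary").decimal(10, 2)))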

