Posted to user@spark.apache.org by Thomas Wang <w...@datability.io> on 2023/04/23 15:50:11 UTC

Spark Aggregator with ARRAY input and ARRAY output

Hi Spark Community,

I'm trying to implement a custom Spark Aggregator (a subclass of
org.apache.spark.sql.expressions.Aggregator). Correct me if I'm wrong, but
I'm assuming I will be able to use it as an aggregation function like SUM.

I have a column of type ARRAY<BOOLEAN>. I would like to GROUP BY another
column and perform an element-wise SUM that counts, for each array position,
the rows in which the boolean flag is set to True. The aggregation should
return ARRAY<LONG>.
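
To make the intended result concrete, here is a small plain-Java sketch of the per-group computation, with no Spark involved. The class name ElementWiseCountExample, the helper countTrue, and the sample rows are made up for illustration only.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class ElementWiseCountExample {

    // Hypothetical helper: for each array position, count the rows whose flag is true.
    public static List<Long> countTrue(List<List<Boolean>> rows) {
        List<Long> counts = new ArrayList<>();
        for (List<Boolean> row : rows) {
            if (row == null) continue; // skip null arrays
            // Grow the result so it is as long as the longest row seen so far.
            while (counts.size() < row.size()) counts.add(0L);
            for (int i = 0; i < row.size(); i++) {
                // Boolean.TRUE.equals treats a null element as "not set".
                if (Boolean.TRUE.equals(row.get(i))) counts.set(i, counts.get(i) + 1);
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        // Three sample rows that belong to the same group.
        List<List<Boolean>> group = Arrays.asList(
            Arrays.asList(true, false, true),
            Arrays.asList(true, true),
            Arrays.asList(false, false, true, true));
        System.out.println(countTrue(group)); // prints [2, 1, 2, 1]
    }
}
```

Rows of different lengths are allowed here: the result is as long as the longest array in the group, and missing positions count as false.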

Here is my implementation so far:

package mypackage.udf;

import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.expressions.Aggregator;

import java.util.ArrayList;
import java.util.List;

public class ElementWiseAgg extends Aggregator<List<Boolean>,
List<Long>, List<Long>> {

    @Override
    public List<Long> zero() {
        return new ArrayList<>();
    }

    @Override
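    public List<Long> reduce(List<Long> b, List<Boolean> a) {
        if (a == null) return b;
        // Pad the buffer with zeros so it is at least as long as the input array.
        int diff = a.size() - b.size();
        for (int i = 0; i < diff; i++) {
            b.add(0L);
        }
        // Count a position only when its flag is true; Boolean.TRUE.equals
        // also skips null elements instead of throwing on unboxing.
        for (int i = 0; i < a.size(); i++) {
            if (Boolean.TRUE.equals(a.get(i))) b.set(i, b.get(i) + 1);
        }
        return b;
    }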

    @Override
    public List<Long> merge(List<Long> b1, List<Long> b2) {
        List<Long> longer;
        List<Long> shorter;
        if (b1.size() > b2.size()) {
            longer = b1;
            shorter = b2;
        } else {
            longer = b2;
            shorter = b1;
        }
        for (int i = 0; i < shorter.size(); i++) {
            longer.set(i, longer.get(i) + shorter.get(i));
        }
        return longer;
    }

    @Override
    public List<Long> finish(List<Long> reduction) {
        return reduction;
    }

    @Override
    public Encoder<List<Long>> bufferEncoder() {
        return null;
    }

    @Override
    public Encoder<List<Long>> outputEncoder() {
        return null;
    }
}
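
The reduce and merge logic above can be sanity-checked without Spark or any encoders. The sketch below mirrors the two methods as static helpers and simulates two partitions of one group; the class name MergeCheck and the sample rows are hypothetical, and it only checks that merging partial buffers gives the same answer as a single pass.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class MergeCheck {

    // Empty buffer, as in zero() above.
    public static List<Long> zero() {
        return new ArrayList<>();
    }

    // Mirrors reduce() above: pad the buffer, then count the true flags.
    public static List<Long> reduce(List<Long> b, List<Boolean> a) {
        if (a == null) return b;
        while (b.size() < a.size()) b.add(0L);
        for (int i = 0; i < a.size(); i++) {
            if (Boolean.TRUE.equals(a.get(i))) b.set(i, b.get(i) + 1);
        }
        return b;
    }

    // Mirrors merge() above: add the shorter buffer into the longer one.
    public static List<Long> merge(List<Long> b1, List<Long> b2) {
        List<Long> longer = b1.size() > b2.size() ? b1 : b2;
        List<Long> shorter = (longer == b1) ? b2 : b1;
        for (int i = 0; i < shorter.size(); i++) {
            longer.set(i, longer.get(i) + shorter.get(i));
        }
        return longer;
    }

    public static void main(String[] args) {
        List<Boolean> r1 = Arrays.asList(true, false, true);
        List<Boolean> r2 = Arrays.asList(true, true);
        List<Boolean> r3 = Arrays.asList(false, false, true, true);

        // Simulated partition 1 holds r1 and r2; partition 2 holds r3.
        List<Long> p1 = reduce(reduce(zero(), r1), r2);
        List<Long> p2 = reduce(zero(), r3);
        List<Long> merged = merge(p1, p2);

        // Single pass over all three rows, for comparison.
        List<Long> singlePass = reduce(reduce(reduce(zero(), r1), r2), r3);

        System.out.println(merged);                    // prints [2, 1, 2, 1]
        System.out.println(merged.equals(singlePass)); // prints true
    }
}
```

Note that both helpers modify and return one of their arguments, just like the methods above, so the buffers must be mutable lists.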

The part I'm not sure about is how to override bufferEncoder and
outputEncoder. The factory methods in the Encoders class do not include an
encoder for List.

Can someone point me in the right direction? Thanks!


Thomas

Re: Spark Aggregator with ARRAY input and ARRAY output

Posted by Thomas Wang <w...@datability.io>.
Thanks Raghavendra,

Could you elaborate on how I can use ExpressionEncoder()? In particular,
how can I make it conform to the return type Encoder<List<Long>>?

Thomas

On Sun, Apr 23, 2023 at 9:42 AM Raghavendra Ganesh <ra...@gmail.com>
wrote:

> For simple array types, setting the encoder to ExpressionEncoder() should work.
> --
> Raghavendra

Re: Spark Aggregator with ARRAY input and ARRAY output

Posted by Raghavendra Ganesh <ra...@gmail.com>.
For simple array types, setting the encoder to ExpressionEncoder() should work.
--
Raghavendra

