package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.internal.Logging // for the Logging trait mixed into CollectListLimit
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.types._

import scala.collection.mutable
@ExpressionDescription(
  usage = "_FUNC_(expr) - Collects and returns a set of unique elements, capped at `limit` elements.")
case class CollectSetLimit(
    child: Expression,
    limit: Int,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect {

  // Counts every update attempt, including those dropped once the buffer is full.
  var attemptedUpdateCount = 0

  def this(child: Expression, limit: Int) = this(child, limit, 0, 0)

  // Map values have no reliable equality in Spark's internal representation, so they
  // cannot be deduplicated into a set (the same restriction applies to collect_set).
  override def checkInputDataTypes(): TypeCheckResult = {
    if (!child.dataType.existsRecursively(_.isInstanceOf[MapType])) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure("collect_set_limit() cannot have map type data")
    }
  }

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def prettyName: String = "collect_set_limit"

  override protected[this] val buffer: mutable.HashSet[Any] = mutable.HashSet.empty

  override def update(b: InternalRow, input: InternalRow): Unit = {
    attemptedUpdateCount += 1
    if (buffer.size < limit) {
      buffer += child.eval(input)
    } else if (attemptedUpdateCount % limit == 0) {
      // Buffer is full; further elements are silently dropped. Insert a log
      // statement, or other code, here if needed.
    }
  }
}
/**
 * Collects a list of non-unique elements, capped at `limit` elements per group.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr) - Collects and returns a list of non-unique elements, capped at `limit` elements.")
case class CollectListLimit(
    child: Expression,
    limit: Int,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect with Logging {

  // Counts every update attempt, including those dropped once the buffer is full.
  var attemptedUpdateCount = 0

  def this(child: Expression, limit: Int) = this(child, limit, 0, 0)

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def prettyName: String = "collect_list_limit"

  override protected[this] val buffer: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty

  override def update(b: InternalRow, input: InternalRow): Unit = {
    attemptedUpdateCount += 1
    if (buffer.size < limit) {
      buffer += child.eval(input)
    } else if (attemptedUpdateCount % limit == 0) {
      // Warn once every `limit` dropped records rather than on every drop.
      logWarning(s"Reached max buffer size: $attemptedUpdateCount/$limit records [${input.toString}]")
    }
  }
}
// Column-function wrappers, mirroring the private helpers in org.apache.spark.sql.functions.
object collect_limit {

  def collect_set_limit(e: Column, limit: Int): Column =
    withAggregateFunction { CollectSetLimit(e.expr, limit) }

  def collect_list_limit(e: Column, limit: Int): Column =
    withAggregateFunction { CollectListLimit(e.expr, limit) }

  private def withAggregateFunction(
      func: AggregateFunction,
      isDistinct: Boolean = false): Column = {
    Column(func.toAggregateExpression(isDistinct))
  }
}
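
A minimal usage sketch, assuming the classes above are compiled against Spark 2.x (the `Collect` base class they override was restructured in later releases) and the resulting jar is on the driver and executor classpath. The `events` DataFrame, its column names, and the limit of 2 are illustrative only:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.collect_limit._

val spark = SparkSession.builder.appName("collect-limit-demo").getOrCreate()
import spark.implicits._

// Illustrative data: several tag rows per user, more than the limit for u1.
val events = Seq(
  ("u1", "a"), ("u1", "b"), ("u1", "c"), ("u1", "a"),
  ("u2", "x")
).toDF("user", "tag")

// Cap each group's buffer at 2 elements instead of collecting everything,
// bounding per-group memory on the executors.
events.groupBy($"user")
  .agg(
    collect_set_limit($"tag", 2).as("tag_set"),
    collect_list_limit($"tag", 2).as("tag_list"))
  .show()

Unlike the built-in collect_set/collect_list, records past the limit are dropped rather than buffered: CollectSetLimit drops them silently, while CollectListLimit emits a throttled warning. The result is a truncated array, not an error.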