// Spark Aggregator that turns per-barcode read spans into a per-position coverage
// map and reports the positions where coverage crosses a fixed cutoff (break points).
// Intended to be run in spark-shell, where `spark` is the active SparkSession.
import scala.collection.mutable.Map
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.types._
import spark.implicits._
// One alignment span of a barcode (bc) on a reference sequence, covering the
// half-open coordinate interval [beg, end).
case class Span(
    ref_name: String,
    bc: String,
    beg: Int,
    end: Int,
    read_count: Int)

// Schema of the tab-separated input file; column names match the Span fields.
val spanSchema = StructType(
  Array(
    StructField("ref_name", StringType, true),
    StructField("bc", StringType, true),
    StructField("beg", IntegerType, true),
    StructField("end", IntegerType, true),
    StructField("read_count", IntegerType, true)
  )
)
// Aggregates the spans of one group into a coverage map, then reduces the
// coverage to break points.
object CalcBreakPoints extends Aggregator[Span, Map[Int, Int], Array[Int]] {

  // A zero value for this aggregation. Satisfies the property that any b + zero = b.
  def zero: Map[Int, Int] = Map[Int, Int]()

  // Combine the buffer with a new span. For performance, the function modifies
  // `buffer` in place and returns it instead of constructing a new object.
  def reduce(buffer: Map[Int, Int], span: Span): Map[Int, Int] = {
    (span.beg until span.end).foreach(
      i => buffer += (i -> (buffer.getOrElse(i, 0) + 1)))
    buffer
  }

  // Merge two intermediate coverage maps.
  def merge(b1: Map[Int, Int], b2: Map[Int, Int]): Map[Int, Int] = {
    b2.foreach {
      case (key, value) => b1 += (key -> (value + b1.getOrElse(key, 0)))
    }
    b1
  }

  // Transform the output of the reduction: convert coverage to break-point
  // coordinates, i.e. positions where being above the coverage cutoff differs
  // from the preceding position. Only coordinates after the first covered
  // position are considered.
  def finish(coverage: Map[Int, Int]): Array[Int] = {
    val covCutoff = 20
    val aboveCutoff = (i: Int) => if (i >= covCutoff) 1 else 0
    val coords = coverage.keys.toArray.sorted
    coords
      .slice(1, coords.length)
      .map { c =>
        val current = aboveCutoff(coverage(c))
        val previous = aboveCutoff(coverage.getOrElse(c - 1, 0))
        (c, current - previous)
      }
      .filter { case (c, d) => d != 0 }
      .map { case (c, d) => c }
  }

  // Encoder for the intermediate value type.
  def bufferEncoder: Encoder[Map[Int, Int]] = Encoders.kryo[Map[Int, Int]]

  // Encoder for the final output value type.
  def outputEncoder: Encoder[Array[Int]] = Encoders.kryo[Array[Int]]
}
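
// A minimal local sanity check with made-up spans (the ref_name, bc, and
// coordinates below are illustrative only). It drives zero/reduce/finish
// directly: 25 reads cover [100, 110) and 5 reads cover [110, 120), so
// coverage drops below the cutoff of 20 at position 110, which finish()
// reports as the single break point.
val deepSpans = Array.fill(25)(Span("chr_test", "bc_test", 100, 110, 1))
val shallowSpans = Array.fill(5)(Span("chr_test", "bc_test", 110, 120, 1))
val testCoverage = (deepSpans ++ shallowSpans).foldLeft(CalcBreakPoints.zero)(CalcBreakPoints.reduce)
assert(CalcBreakPoints.finish(testCoverage).sameElements(Array(110)))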
// Read the tab-separated spans, compute the break points of each ref_name, and
// write the result out as parquet.
val ds = spark.read.option("sep", "\t").schema(spanSchema).csv("/projects/btl/zxue/assembly_correction/celegans/toy_cov.csv").as[Span]
val cc = CalcBreakPoints.toColumn.name("bp")
val res = ds.groupByKey(a => a.ref_name).agg(cc)
res.write.format("parquet").save("./lele.parquet")