cs525-sp18-g07 / spark / Commits / f7149c5e

Commit f7149c5e, authored 12 years ago by Imran Rashid, committed by Matei Zaharia 12 years ago
tasks cannot access value of accumulator
parent 244cbbe3
Showing 2 changed files with 21 additions and 56 deletions:
core/src/main/scala/spark/Accumulators.scala (+9, −3)
core/src/test/scala/spark/AccumulatorSuite.scala (+12, −53)
core/src/main/scala/spark/Accumulators.scala  +9 −3
@@ -11,7 +11,7 @@ class Accumulable[T,R] (
   val id = Accumulators.newId
   @transient
-  var value_ = initialValue // Current value on master
+  private var value_ = initialValue // Current value on master
   val zero = param.zero(initialValue)  // Zero value to be passed to workers
   var deserialized = false
@@ -30,7 +30,13 @@ class Accumulable[T,R] (
    * @param term the other Accumulable that will get merged with this
    */
   def ++= (term: T) { value_ = param.addInPlace(value_, term)}
-  def value = this.value_
+  def value = {
+    if (!deserialized) value_
+    else throw new UnsupportedOperationException("Can't use read value in task")
+  }
+
+  private[spark] def localValue = value_
+
   def value_= (t: T) {
     if (!deserialized) value_ = t
     else throw new UnsupportedOperationException("Can't use value_= in task")
@@ -126,7 +132,7 @@ private object Accumulators {
   def values: Map[Long, Any] = synchronized {
     val ret = Map[Long, Any]()
     for ((id, accum) <- localAccums.getOrElse(Thread.currentThread, Map())) {
-      ret(id) = accum.value
+      ret(id) = accum.localValue
     }
     return ret
   }
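Taken together, the three hunks above tighten the accumulator contract: value_ becomes private, value only returns it on the driver (where deserialized is false) and throws inside a task, and the new private[spark] localValue accessor lets Accumulators.values keep collecting each task's partial result. A minimal driver-side sketch of the resulting behaviour, assuming the 0.7-era spark API used in this commit (the context string, object name, and sample data are illustrative, not taken from the diff):

import spark.SparkContext

object AccumulatorAccessSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "sketch")  // hypothetical local context
    val sum = sc.accumulator(0)

    // Tasks may still add to the accumulator with +=.
    sc.parallelize(1 to 100).foreach(x => sum += x)

    // The merged result is read back on the driver, where deserialized == false.
    println("sum = " + sum.value)

    // After this commit, reading the value inside a task is an error: the closure
    // below would throw UnsupportedOperationException("Can't use read value in task")
    // when the task runs, so it is left commented out here.
    // sc.parallelize(1 to 100).foreach(x => println(sum.value))
  }
}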
core/src/test/scala/spark/AccumulatorSuite.scala  +12 −53
@@ -63,60 +63,19 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers {
   }
 
-  test ("value readable in tasks") {
-    import spark.util.Vector
-    //stochastic gradient descent with weights stored in accumulator -- should be able to read value as we go
-
-    //really easy data
-    val N = 10000  // Number of data points
-    val D = 10   // Numer of dimensions
-    val R = 0.7   // Scaling factor
-    val ITERATIONS = 5
-    val rand = new Random(42)
-
-    case class DataPoint(x: Vector, y: Double)
-
-    def generateData = {
-      def generatePoint(i: Int) = {
-        val y = if(i % 2 == 0) -1 else 1
-        val goodX = Vector(D, _ => 0.0001 * rand.nextGaussian() + y)
-        val noiseX = Vector(D, _ => rand.nextGaussian())
-        val x = Vector((goodX.elements.toSeq ++ noiseX.elements.toSeq): _*)
-        DataPoint(x, y)
-      }
-      Array.tabulate(N)(generatePoint)
-    }
-
-    val data = generateData
-    for (nThreads <- List(1, 10)) {
-      //test single & multi-threaded
-      val sc = new SparkContext("local[" + nThreads + "]", "test")
-      val weights = Vector.zeros(2*D)
-      val weightDelta = sc.accumulator(Vector.zeros(2 * D))
-      for (itr <- 1 to ITERATIONS) {
-        val eta = 0.1 / itr
-        val badErrs = sc.accumulator(0)
-        sc.parallelize(data).foreach {
-          p => {
-            //XXX Note the call to .value here.  That is required for this to be an online gradient descent
-            // instead of a batch version.  Should it change to .localValue, and should .value throw an error
-            // if you try to do this??
-            val prod = weightDelta.value.plusDot(weights, p.x)
-            val trueClassProb = (1 / (1 + exp(-p.y * prod)))  // works b/c p(-z) = 1 - p(z) (where p is the logistic function)
-            val update = p.x * trueClassProb * p.y * eta
-            //we could also include a momentum term here if our weightDelta accumulator saved a momentum
-            weightDelta.value += update
-            if (trueClassProb <= 0.95)
-              badErrs += 1
-          }
+  test ("value not readable in tasks") {
+    import SetAccum._
+    val maxI = 1000
+    for (nThreads <- List(1, 10)) { //test single & multi-threaded
+      val sc = new SparkContext("local[" + nThreads + "]", "test")
+      val acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
+      val d = sc.parallelize(1 to maxI)
+      val thrown = evaluating {
+        d.foreach {
+          x => acc.value += x
         }
-        println("Iteration " + itr + " had badErrs = " + badErrs.value)
-        weights += weightDelta.value
-        println(weights)
-        //TODO I should check the number of bad errors here, but for some reason spark tries to serialize the assertion ...
-//        val assertVal = badErrs.value
-//        assert (assertVal < 100)
-      }
+      } should produce [SparkException]
+      println(thrown)
     }
   }
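The removed test relied on reading weightDelta.value inside the foreach closure (online rather than batch gradient descent), which this commit turns into an error, so it is replaced by a test that asserts the failure instead: the task-side read throws UnsupportedOperationException, the job fails, and the failure reaches the driver as the SparkException the new test matches on. The evaluating { ... } should produce [SparkException] construction is ScalaTest 1.x ShouldMatchers syntax that runs the block, expects the given exception type, and returns the caught exception. A stripped-down sketch of that assertion pattern on its own, with hypothetical names and no Spark involved:

import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers

class ExpectedFailureSketch extends FunSuite with ShouldMatchers {
  test("evaluating ... should produce returns the caught exception") {
    val thrown = evaluating {
      // stand-in for the task-side failure the new test expects
      throw new UnsupportedOperationException("Can't use read value in task")
    } should produce [UnsupportedOperationException]
    thrown.getMessage should include ("task")
  }
}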