/* scala-stm - (c) 2009-2014, Stanford University, PPL */

package scala.concurrent.stm.japi

import java.util.concurrent.Callable
import java.util.concurrent.TimeUnit
import java.util.{List => JList, Map => JMap, Set => JSet}

import scala.collection.JavaConverters._
import scala.concurrent.stm
import scala.concurrent.stm._
import scala.language.implicitConversions

private[japi] object STMHelpers {
  // Why these conversions live in their own object instead of in `STM`:
  //
  //  * Scala 2.8.2 generated anonymous classes from inside type-parameterized
  //    methods that Java could not digest. Keeping them here means they are
  //    not pulled in by a Java caller writing
  //
  //      import static scala.concurrent.stm.japi.STM.*;
  //
  //  * `scala.Function1` is hard to implement from Java, because Java does not
  //    automatically wire up its mix-in methods. `scala.runtime.AbstractFunction1`
  //    would be nicer, but its interface is not guaranteed to stay stable (and
  //    the Scala 2.8.2 version is not usable from Java). Instead, Java code
  //    implements `STM.Transformer<A>`, which is essentially a
  //    `Function1[A, A]`, and the conversion below adapts it.

  /** Adapts a Java `Callable` into an atomic block; the `InTxn` argument is ignored. */
  implicit def callableToAtomicBlock[A <: AnyRef](f: Callable[A]): InTxn => A =
    (_: InTxn) => f.call()

  /** Adapts an `STM.Transformer` into an ordinary Scala `A => A`. */
  implicit def transformerToFunction[A <: AnyRef](f: STM.Transformer[A]): A => A =
    (value: A) => f.apply(value)
}

/**
 * Java-friendly API for ScalaSTM.
 * These methods can also be statically imported, e.g. with
 * `import static scala.concurrent.stm.japi.STM.*;`.
 */
object STM {
  import STMHelpers._

  /**
   * Create a Ref with an initial value. Return a `Ref.View`, which does not
   * require implicit transactions.
   * @param initialValue the initial value for the newly created `Ref.View`
   * @return a new `Ref.View`
   */
  def newRef[A](initialValue: A): Ref.View[A] = Ref(initialValue).single

  /**
   * Create an empty TMap. Return a `TMap.View`, which does not require
   * implicit transactions. See `newMap` for a `java.util.Map` view.
   * @return a new, empty `TMap.View`
   */
  def newTMap[A, B](): TMap.View[A, B] = TMap.empty[A, B].single

  /**
   * Create an empty TMap. Return a `java.util.Map` view of this TMap.
   * @return a new, empty `TMap.View` wrapped as a `java.util.Map`
   */
  def newMap[A, B](): JMap[A, B] = newTMap[A, B]().asJava

  /**
   * Create an empty TSet. Return a `TSet.View`, which does not require
   * implicit transactions. See `newSet` for a `java.util.Set` view.
   * @return a new, empty `TSet.View`
   */
  def newTSet[A](): TSet.View[A] = TSet.empty[A].single

  /**
   * Create an empty TSet. Return a `java.util.Set` view of this TSet.
   * @return a new, empty `TSet.View` wrapped as a `java.util.Set`
   */
  def newSet[A](): JSet[A] = newTSet[A]().asJava

  /**
   * Create a TArray containing `length` elements. Return a `TArray.View`,
   * which does not require implicit transactions. See `newArrayAsList` for a
   * `java.util.List` view.
   * @param length the length of the `TArray.View` to be created
   * @return a new `TArray.View` containing `length` elements (initially null)
   */
  def newTArray[A <: AnyRef](length: Int): TArray.View[A] =
    // Allocated as TArray[AnyRef] and then cast. This is presumably because
    // `ofDim` needs element-type evidence that the generic `A` cannot supply;
    // the cast is unchecked (erased) and safe because A <: AnyRef.
    TArray.ofDim[AnyRef](length).asInstanceOf[TArray[A]].single

  /**
   * Create a TArray containing `length` elements (initially null). Return a
   * `java.util.List` view of this TArray.
   * @param length the length of the `TArray.View` to be created
   * @return a new `TArray.View` of `length` null elements, wrapped as a
   *         `java.util.List`
   */
  def newArrayAsList[A <: AnyRef](length: Int): JList[A] = newTArray[A](length).asJava

  /**
   * Atomic block that takes a `Runnable`.
   * @param runnable the `Runnable` to run within a transaction
   */
  def atomic(runnable: Runnable): Unit = stm.atomic { _ => runnable.run() }

  /**
   * Atomic block that takes a `Callable`.
   * @param callable the `Callable` to run within a transaction
   * @return the value returned by the `Callable`
   */
  def atomic[A <: AnyRef](callable: Callable[A]): A = stm.atomic(callable)

  /**
   * Causes the enclosing transaction to back up and wait until one
   * of the `Ref`s touched by this transaction has changed.
   * @throws IllegalStateException if not in a transaction
   */
  def retry(): Unit = Txn.retry(currentTxn("retry outside atomic"))

  /**
   * Like `retry`, but limits the total amount of blocking.  This method
   * only returns normally when the timeout has expired.
   * @param timeoutMillis the maximum total time to block, in milliseconds
   * @throws IllegalStateException if not in a transaction
   */
  def retryFor(timeoutMillis: Long): Unit =
    Txn.retryFor(timeoutMillis)(currentTxn("retry outside atomic"))

  /**
   * Like `retry`, but limits the total amount of blocking to `timeout`
   * expressed in `unit`.  This method only returns normally when the
   * timeout has expired.
   * @param timeout the maximum total time to block, in units of `unit`
   * @param unit the time unit of `timeout`
   * @throws IllegalStateException if not in a transaction
   */
  def retryFor(timeout: Long, unit: TimeUnit): Unit =
    Txn.retryFor(timeout, unit)(currentTxn("retry outside atomic"))

  /**
   * A Java-friendly stand-in for the Scala function type `A => A`, accepted
   * by `transform`, `getAndTransform` and `transformAndGet`.
   * @tparam A the type of the value being transformed
   */
  abstract class Transformer[A <: AnyRef] {
    /** Compute the replacement for the current value `v`. */
    def apply(v: A): A
  }

  /**
   * Transform the value stored by `ref` by applying the function `f`.
   * @param ref the `Ref.View` to be transformed
   * @param f the function to be applied
   */
  def transform[A <: AnyRef](ref: Ref.View[A], f: Transformer[A]): Unit = ref.transform(f)

  /**
   * Transform the value stored by `ref` by applying the function `f` and
   * return the old value.
   * @param ref the `Ref.View` to be transformed
   * @param f the function to be applied
   * @return the old value of `ref`
   */
  def getAndTransform[A <: AnyRef](ref: Ref.View[A], f: Transformer[A]): A = ref.getAndTransform(f)

  /**
   * Transform the value stored by `ref` by applying the function `f` and
   * return the new value.
   * @param ref the `Ref.View` to be transformed
   * @param f the function to be applied
   * @return the new value of `ref`
   */
  def transformAndGet[A <: AnyRef](ref: Ref.View[A], f: Transformer[A]): A = ref.transformAndGet(f)

  /**
   * Increment the `java.lang.Integer` value of a `Ref.View`.
   * The stored value is unboxed, so the ref must not hold null.
   * @param ref the `Ref.View<Integer>` to be incremented
   * @param delta the amount to increment
   */
  def increment(ref: Ref.View[java.lang.Integer], delta: Int): Unit = ref.transform(_.intValue + delta)

  /**
   * Increment the `java.lang.Long` value of a `Ref.View`.
   * The stored value is unboxed, so the ref must not hold null.
   * @param ref the `Ref.View<Long>` to be incremented
   * @param delta the amount to increment
   */
  def increment(ref: Ref.View[java.lang.Long], delta: Long): Unit = ref.transform(_.longValue + delta)

  /**
   * Returns the transaction that `Txn.findCurrent` reports as active.
   * Otherwise throws an `IllegalStateException` carrying `message`.
   */
  private def currentTxn(message: String): InTxn =
    Txn.findCurrent.getOrElse(throw new IllegalStateException(message))

  /** The current transaction, used by the after-completion hooks below. */
  private def activeTxn: InTxn = currentTxn("not in a transaction")

  /**
   * Add a task to run after the current transaction has committed.
   * @param task the `Runnable` task to run after transaction commit
   * @throws IllegalStateException if called from outside a transaction
   */
  def afterCommit(task: Runnable): Unit = Txn.afterCommit(_ => task.run())(activeTxn)

  /**
   * Add a task to run after the current transaction has rolled back.
   * @param task the `Runnable` task to run after transaction rollback
   * @throws IllegalStateException if called from outside a transaction
   */
  def afterRollback(task: Runnable): Unit = Txn.afterRollback(_ => task.run())(activeTxn)

  /**
   * Add a task to run after the current transaction has either rolled back
   * or committed.
   * @param task the `Runnable` task to run after transaction completion
   * @throws IllegalStateException if called from outside a transaction
   */
  def afterCompletion(task: Runnable): Unit = Txn.afterCompletion(_ => task.run())(activeTxn)
}
