/**
 * Copyright (c) 2002, Peace Technology, Inc.
 * $Author:Roytman, Alex$
 * $Revision$
 * $Date$
 * $NoKeywords$
 */

package com.peacetech.webtools.tomcat.dbcp;

import javax.naming.Context;
import javax.naming.Name;
import javax.naming.Reference;
import javax.naming.spi.ObjectFactory;
import java.util.Hashtable;
import java.util.Timer;
import java.util.Date;
import java.util.TimerTask;
import java.util.Map;
import java.util.WeakHashMap;
import java.lang.ref.WeakReference;

/**
 * A {@link TimerService} that runs every scheduled {@link Runnable} on one daemon {@link Timer}
 * thread, and that doubles as the JNDI {@link ObjectFactory} which hands itself out.
 *
 * <p><b>Weak references:</b> the service never keeps a scheduled {@code Runnable} reachable. The
 * timer task wrapping it holds it through a {@link WeakReference}, and the cancellation registry
 * is a {@link WeakHashMap} keyed by it. Callers must keep their own strong reference to a
 * {@code Runnable} for as long as it should keep running; once it has been garbage collected,
 * its timer task stops running and cancels itself.
 *
 * <p><b>Failures:</b> an unchecked exception thrown by a scheduled {@code Runnable} is reported to
 * the timer thread's {@link ThreadGroup} instead of propagating, because {@link Timer} cancels
 * itself permanently when one of its tasks throws, which would silently stop every other task
 * sharing this service.
 *
 * <p>NOTE(review): every instance owns its own timer thread (created by the {@code Timer}
 * constructor). If the container instantiates a new factory for each JNDI lookup, each lookup
 * starts another thread. Confirm how the resource is cached.
 */
public class TimerServiceImpl extends Timer implements ObjectFactory, TimerService {
  /**
   * Cancellation registry: maps each scheduled {@code Runnable} to the timer task most recently
   * scheduled for it. Keys are weak, and the values only reference their key weakly, so an entry
   * vanishes once its {@code Runnable} becomes unreachable. (A value strongly referring to its
   * key would pin the key forever.) Guarded by {@code this}.
   */
  private final Map tasks = new WeakHashMap();

  /** Creates the service; its timer thread is a daemon, so it does not keep the JVM alive. */
  public TimerServiceImpl() {
    super(true);
  }

  /**
   * Runs {@code task} once after {@code delay} milliseconds.
   *
   * @see Timer#schedule(TimerTask, long)
   */
  public void schedule(Runnable task, long delay) {
    RunnableTimerTask timerTask = new RunnableTimerTask(task);
    super.schedule(timerTask, delay);
    track(task, timerTask);
  }

  /**
   * Runs {@code task} once at {@code time}, or as soon as possible if {@code time} has passed.
   *
   * @see Timer#schedule(TimerTask, Date)
   */
  public void schedule(Runnable task, Date time) {
    RunnableTimerTask timerTask = new RunnableTimerTask(task);
    super.schedule(timerTask, time);
    track(task, timerTask);
  }

  /**
   * Runs {@code task} repeatedly with fixed-delay execution: first after {@code delay}
   * milliseconds, then {@code period} milliseconds after each previous execution.
   *
   * @see Timer#schedule(TimerTask, long, long)
   */
  public void schedule(Runnable task, long delay, long period) {
    RunnableTimerTask timerTask = new RunnableTimerTask(task);
    super.schedule(timerTask, delay, period);
    track(task, timerTask);
  }

  /**
   * Runs {@code task} repeatedly with fixed-delay execution, starting at {@code firstTime}.
   *
   * @see Timer#schedule(TimerTask, Date, long)
   */
  public void schedule(Runnable task, Date firstTime, long period) {
    RunnableTimerTask timerTask = new RunnableTimerTask(task);
    super.schedule(timerTask, firstTime, period);
    track(task, timerTask);
  }

  /**
   * Runs {@code task} repeatedly with fixed-rate execution: first after {@code delay}
   * milliseconds, then every {@code period} milliseconds measured from the first execution.
   *
   * @see Timer#scheduleAtFixedRate(TimerTask, long, long)
   */
  public void scheduleAtFixedRate(Runnable task, long delay, long period) {
    RunnableTimerTask timerTask = new RunnableTimerTask(task);
    super.scheduleAtFixedRate(timerTask, delay, period);
    track(task, timerTask);
  }

  /**
   * Runs {@code task} repeatedly with fixed-rate execution, starting at {@code firstTime}.
   *
   * @see Timer#scheduleAtFixedRate(TimerTask, Date, long)
   */
  public void scheduleAtFixedRate(Runnable task, Date firstTime, long period) {
    RunnableTimerTask timerTask = new RunnableTimerTask(task);
    super.scheduleAtFixedRate(timerTask, firstTime, period);
    track(task, timerTask);
  }

  /**
   * JNDI factory hook. Returns this instance when {@code obj} is a {@link Reference} whose class
   * name equals this class's name. Otherwise it returns {@code null}, which per the
   * {@link ObjectFactory} contract signals that this factory cannot create the object.
   *
   * @param obj the reference being resolved
   * @param name unused
   * @param nameCtx unused
   * @param environment unused
   * @return {@code this} or {@code null}
   */
  public Object getObjectInstance(Object obj, Name name, Context nameCtx,
                                  Hashtable environment) throws Exception {
    if (obj instanceof Reference) {
      Reference ref = (Reference)obj;
      if (getClass().getName().equals(ref.getClassName())) {
        return this;
      }
    }
    return null;
  }

  /**
   * Cancels a scheduled task.
   *
   * @param task either a {@link TimerTask}, which is cancelled directly, or a {@code Runnable}
   *     previously passed to one of the {@code schedule} methods. For a {@code Runnable}, the
   *     most recent timer task is cancelled and forgotten. Unknown objects (including
   *     {@code null}) are ignored.
   */
  public synchronized void cancelTask(Object task) {
    // Consult the registry first. A TimerTask is itself a Runnable, so it may have been scheduled
    // through schedule(Runnable, ...) and be running inside a wrapper that must be stopped too.
    TimerTask scheduled = (TimerTask)tasks.remove(task);
    if (scheduled != null) {
      scheduled.cancel();
    }
    if (task instanceof TimerTask) {
      ((TimerTask)task).cancel();
    }
  }

  /**
   * Records {@code timerTask} as the cancellable task for {@code runnable}. This is called only
   * after the timer has accepted the task, so a rejected schedule call (bad delay, cancelled
   * timer) cannot displace a live registration.
   *
   * <p>NOTE(review): only the latest task per {@code Runnable} is kept. If the same
   * {@code Runnable} is scheduled twice, the earlier task can no longer be cancelled through
   * {@link #cancelTask(Object)}. It stops only once the {@code Runnable} is garbage collected.
   */
  private synchronized void track(Runnable runnable, TimerTask timerTask) {
    tasks.put(runnable, timerTask);
  }

  /** Timer task that runs a weakly-referenced {@code Runnable}. */
  private static class RunnableTimerTask extends TimerTask {
    private final WeakReference runnableRef;

    /**
     * @param runnable the code to run; held only weakly
     * @throws NullPointerException if {@code runnable} is {@code null}
     */
    public RunnableTimerTask(Runnable runnable) {
      if (runnable == null) {
        throw new NullPointerException();
      }
      this.runnableRef = new WeakReference(runnable);
    }

    /**
     * Runs the wrapped {@code Runnable} if it is still reachable. Otherwise it cancels this
     * task, so a periodic schedule does not linger in the timer queue as a permanent no-op.
     * Unchecked exceptions are reported rather than rethrown, which keeps the shared timer
     * thread alive.
     */
    public void run() {
      Runnable runnable = (Runnable)runnableRef.get();
      if (runnable == null) {
        cancel();
        return;
      }
      try {
        runnable.run();
      } catch (RuntimeException e) {
        // Rethrowing would terminate the Timer thread and cancel the whole service.
        Thread current = Thread.currentThread();
        current.getThreadGroup().uncaughtException(current, e);
      }
    }
  }
}
